fix(pocket_alignment): in place modification, offline setup, and edge index renumbering #103

- Had to make some modifications since the edge index needs to be updated after applying the mask so that it still points to the right nodes and we don't get an IndexError for being out of bounds (see the sketch below).

- Also fixed an error caused by not removing all proteins without pocket sequences (line 216 saved the old dataset instead of the new one).

- Successfully built pocket datasets for davis and kiba

#131 #103
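The renumbering itself reduces to a cumulative-sum lookup table over the boolean mask. A minimal sketch of the idea (toy mask and edge, not the project's data):

import numpy as np

mask = [False, True, True, False, True]  # keep nodes 1, 2, 4
node_map = np.cumsum(mask) - 1           # [-1, 0, 1, 1, 2]; only entries where mask is True are meaningful
# an edge (1, 4) between two surviving nodes becomes (0, 2) in the masked graph
old_edge = (1, 4)
new_edge = (node_map[old_edge[0]], node_map[old_edge[1]])
assert tuple(new_edge) == (0, 2)
# edges touching a masked-out node (e.g. (0, 1)) are dropped rather than remapped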
jyaacoub committed Aug 7, 2024
1 parent a8dce15 commit c163778
Showing 3 changed files with 142 additions and 54 deletions.
112 changes: 83 additions & 29 deletions playground.py
@@ -1,3 +1,37 @@
# # %%
# import numpy as np
# import torch

# d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt")
# np.array(list(d['ABL1(F317I)p'].pro_seq))[d['ABL1(F317I)p'].pocket_mask].shape



# %%
# building pocket datasets:
from src.utils.pocket_alignment import pocket_dataset_full
import shutil
import os

data_dir = '/cluster/home/t122995uhn/projects/data/'
db_type = ['kiba', 'davis']
db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary',
           'nomsa_binary_gvp_binary', 'nomsa_aflow_gvp_binary']

for t in db_type:
    for f in db_feat:
        print(f'\n---{t}-{f}---\n')
        dataset_dir = f"{data_dir}/DavisKibaDataset/{t}/{f}/full"
        save_dir = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full"

        pocket_dataset_full(
            dataset_dir=dataset_dir,
            pocket_dir=f"{data_dir}/{t}/",
            save_dir=save_dir,
            skip_download=True
        )


#%%
import pandas as pd

@@ -37,45 +71,65 @@ def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/dat


#%%
-########################################################################
-########################## BUILD DATASETS ##############################
-########################################################################
+##############################################################################
+########################## BUILD/SPLIT DATASETS ##############################
+##############################################################################
import os
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

-splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
-create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba],
+dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba]
+splits = ['davis', 'kiba']
+splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s for s in splits]
+print(splits)
+
+#%%
+for split, db in zip(splits, dbs):
+    print('\n', split, db)
+    create_datasets(db,
                    feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                    edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
                    ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
                    ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
                    k_folds=5,
-                   test_prots_csv=f'{splits}/test.csv',
-                   val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],)
-                   # data_root=os.path.abspath('../data/test/'))
+                   test_prots_csv=f'{split}/test.csv',
+                   val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)])

# %% Copy splits to commit them:
# from data dir (v131) to repo splits dir:
import shutil
from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/'
to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/'
from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis']
to_db = ['pdbbind', 'kiba', 'davis']

from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db]
to_db = [f'{to_dir_p}/{f}' for f in to_db]

for src, dst in zip(from_db, to_db):
    for x in ['train', 'val']:
        for i in range(5):
            print(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")
            shutil.copy(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")

    print(f"{src}/test/XY.csv", f"{dst}/test.csv")
    shutil.copy(f"{src}/test/XY.csv", f"{dst}/test.csv")


#%% TEST INFERENCE
from src import cfg
from src.utils.loader import Loader

# db2 = Loader.load_dataset(cfg.DATA_OPT.davis,
# cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
# path='/cluster/home/t122995uhn/projects/data/',
# subset="full")

db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis,
                              cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                              path='/cluster/home/t122995uhn/projects/data/v131',
                              training_fold=0,
                              batch_train=2)
for b2 in db2['test']: break  # grab a single test batch


# %%
m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                      dropout=0.3480, output_dim=256)

#%%
# m(b['protein'], b['ligand'])
m(b2['protein'], b2['ligand'])
#%%
model = m
loaders = db2
device = 'cpu'
NUM_EPOCHS = 1
LEARNING_RATE = 0.001
from src.train_test.training import train

logs = train(model, loaders['train'], loaders['val'], device,
             epochs=NUM_EPOCHS, lr_0=LEARNING_RATE)
# %%
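A quick sanity check on one rebuilt protein record could look like the following (the path and the 'ABL1(F317I)p' key come from the commented snippet at the top of this file; the record layout with pro_seq, pocket_mask, x, and edge_index follows mask_graph in src/utils/pocket_alignment.py, so this is a sketch under that assumption):

import numpy as np
import torch

d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt")
prot = d['ABL1(F317I)p']
pocket_len = np.array(list(prot.pro_seq))[prot.pocket_mask].shape[0]
assert prot.x.shape[0] == pocket_len             # node features were masked in place
assert int(prot.edge_index.max()) < pocket_len   # renumbered edges stay in bounds (no IndexError)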
67 changes: 43 additions & 24 deletions src/utils/pocket_alignment.py
@@ -9,6 +9,7 @@

from Bio import Align
from Bio.Align import substitution_matrices
+import numpy as np
import pandas as pd
import torch

@@ -78,26 +79,34 @@ def mask_graph(data, mask: list[bool]):
    additional attributes:
        -pocket_mask : list[bool]
            The mask specified by the mask parameter of dimension [full_sequence_length]
-       -pocket_mask_x : torch.Tensor
+       -x : torch.Tensor
            The nodes of only the pocket of the protein sequence of dimension
            [pocket_sequence_length, num_features]
-       -pocket_mask_edge_index : torch.Tensor
+       -edge_index : torch.Tensor
            The edge connections in COO format only relating to
            the pocket nodes of the protein sequence of dimension [2, num_pocket_edges]
    """
+   # node map for updating edge indices after mask
+   node_map = np.cumsum(mask) - 1
+
    nodes = data.x[mask]
-   edges = data.edge_index
+   edges = []
    edge_mask = []
-   for i in range(edges.shape[1]):
-       # Throw out edges that are connected to at least one node not in the
-       # binding pocket
-       node_1, node_2 = edges[:,i][0], edges[:,i][1]
-       edge_mask.append(True) if mask[node_1] and mask[node_2] else edge_mask.append(False)
-   edges = torch.transpose(torch.transpose(edges, 0, 1)[edge_mask], 0, 1)
+   for i in range(data.edge_index.shape[1]):
+       # Throw out edges that are not part of connecting two nodes in the pocket...
+       node_1, node_2 = data.edge_index[:,i][0], data.edge_index[:,i][1]
+       if mask[node_1] and mask[node_2]:
+           # append mapped index:
+           edges.append([node_map[node_1], node_map[node_2]])
+           edge_mask.append(True)
+       else:
+           edge_mask.append(False)

+   data.x = nodes
    data.pocket_mask = mask
-   data.pocket_mask_x = nodes
-   data.pocket_mask_edge_index = edges
+   data.edge_index = torch.tensor(edges).T  # reshape to (2, E)
+   if 'edge_weight' in data:
+       data.edge_weight = data.edge_weight[edge_mask]
    return data


@@ -122,7 +131,8 @@ def _parse_json(json_path: str) -> str:

def get_dataset_binding_pockets(
        dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full',
-       pockets_path: str = 'data/DavisKibaDataset/kiba_pocket'
+       pockets_path: str = 'data/DavisKibaDataset/kiba_pocket',
+       skip_download: bool = False,
    ) -> tuple[dict[str, str], set[str]]:
"""
Get all binding pocket sequences for a dataset
@@ -149,21 +159,27 @@
    # Strip out mutations and '-(alpha, beta, gamma)' tags if they are present,
    # the binding pocket sequence will be the same for mutated and non-mutated genes
    prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids]
-   dl = Downloader()
    seq_save_dir = os.path.join(pockets_path, 'pockets')
-   os.makedirs(seq_save_dir, exist_ok=True)
-   download_check = dl.download_pocket_seq(prot_ids, seq_save_dir)
+
+   if not skip_download: # to use cached downloads only! (useful when on compute node)
+       dl = Downloader()
+       os.makedirs(seq_save_dir, exist_ok=True)
+       dl.download_pocket_seq(prot_ids, seq_save_dir)
+
    download_errors = set()
-   for key, val in download_check.items():
-       if val == 400:
-           download_errors.add(key)
    sequences = {}
    for file in os.listdir(seq_save_dir):
        pocket_seq = _parse_json(os.path.join(seq_save_dir, file))
        if pocket_seq == 0 or len(pocket_seq) == 0:
            download_errors.add(file.split('.')[0])
        else:
            sequences[file.split('.')[0]] = pocket_seq
+
+   # adding any remainder prots not downloaded.
+   for p in prot_ids:
+       if p not in sequences:
+           download_errors.add(p)

    return (sequences, download_errors)


@@ -197,7 +213,7 @@ def create_binding_pocket_dataset(
        new_data = mask_graph(data, mask)
        new_dataset[id] = new_data
    os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True)
-   torch.save(dataset, new_dataset_path)
+   torch.save(new_dataset, new_dataset_path)


def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str):
@@ -215,16 +231,17 @@
    csv_save_path : str
        The path to save the new CSV file to.
    """
-   df = pd.read_csv(dataset_csv_path)
-   df = df[~df['prot_id'].isin(download_errors)]
+   df = pd.read_csv(dataset_csv_path, index_col=0)
+   df = df[~df.prot_id.str.split('(').str[0].str.split('-').str[0].isin(download_errors)]
    os.makedirs(os.path.dirname(csv_save_path), exist_ok=True)
    df.to_csv(csv_save_path)


def pocket_dataset_full(
        dataset_dir: str,
        pocket_dir: str,
-       save_dir: str
+       save_dir: str,
+       skip_download: bool = False
    ) -> None:
"""
Create all elements of a dataset that includes binding pockets. This
@@ -240,7 +257,7 @@
    save_dir : str
        The path to where the new dataset is to be saved
    """
-   pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir)
+   pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir, skip_download)
print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:')
print(','.join(list(download_errors)))
create_binding_pocket_dataset(
@@ -254,7 +271,9 @@
        download_errors,
        os.path.join(save_dir, 'cleaned_XY.csv')
    )
-   shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt'))
+   if dataset_dir != save_dir:
+       shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt'))
+       shutil.copy2(os.path.join(dataset_dir, 'XY.csv'), os.path.join(save_dir, 'XY.csv'))


if __name__ == '__main__':
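For a self-contained check of the new mask_graph behavior, a toy graph is enough (this sketch assumes the per-protein entries are torch_geometric Data objects, which the x/edge_index/edge_weight attributes above suggest):

import torch
from torch_geometric.data import Data
from src.utils.pocket_alignment import mask_graph

x = torch.randn(5, 3)                      # 5 residues, 3 features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])  # chain 0-1-2-3-4
mask = [False, True, True, True, False]    # pocket = residues 1..3

d = mask_graph(Data(x=x, edge_index=edge_index), mask)
assert d.x.shape == (3, 3)                        # only pocket nodes remain
assert d.edge_index.tolist() == [[0, 1], [1, 2]]  # edges 1-2 and 2-3, renumbered in bounds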
17 changes: 16 additions & 1 deletion train_test.py
@@ -2,7 +2,22 @@
from src.utils.arg_parse import parse_train_test_args

args, unknown_args = parse_train_test_args(verbose=True,
-           jyp_args='-m DG -d PDBbind -f nomsa -e binary -bs 64')
+           jyp_args='--model_opt DG \
+                     --data_opt davis \
+                     \
+                     --feature_opt nomsa \
+                     --edge_opt binary \
+                     --ligand_feature_opt original \
+                     --ligand_edge_opt binary \
+                     \
+                     --learning_rate 0.00012 \
+                     --batch_size 128 \
+                     --dropout 0.24 \
+                     --output_dim 128 \
+                     \
+                     --train \
+                     --fold_selection 0 \
+                     --num_epochs 2000')
FORCE_TRAINING = args.train
DEBUG = args.debug

