Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
asarigun committed Oct 27, 2023
1 parent abfba51 commit 6e4a0d1
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 116 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ We provide the implementation of the DrugGEN, along with scripts from PyTorch Ge
- ```samples``` folder. Molecule samples will be saved in this folder.
- ```inference``` folder. Molecules generated in inference mode will be saved in this folder.

**Python scripts are:**
**Python scripts:**

- ```layers.py``` file contains **Transformer Encoder** and **Transformer Decoder** implementations.
- ```layers.py``` contains **transformer encoder** and **transformer decoder** implementations.
- ```main.py``` defines the command-line arguments and is the file used to run the model.
- ```models.py``` has the implementation of the **Generators** and **Discriminators** which are used in GAN1 and GAN2.
- ```new_dataloader.py``` constructs the graph dataset from the given raw data, using PyG-based data classes.
- ```trainer.py``` is the training and testing file for the model. Workflow is constructed in this file.
- ```utils.py``` contains performance metrics from several other papers and some unique implementations. (De Cao et al, 2018; Polykovskiy et al., 2020)
- ```utils.py``` contains performance metrics from several other papers and some unique implementations (De Cao et al., 2018; Polykovskiy et al., 2020).

## Datasets
Three different data types (i.e., compound, protein, and bioactivity) were retrieved from various data sources to train our deep generative models. The GAN1 module requires only compound data, while GAN2 requires all three data types: compound, protein, and bioactivity.
Expand Down Expand Up @@ -160,7 +160,7 @@ bash dataset_download.sh

# DrugGEN can be trained with a one-liner

python DrugGEN/main.py --mode="train" --device="cuda" --raw_file="DrugGEN/data/chembl_smiles.smi" --dataset_file="chembl45.pt" -- drug_raw_file="drug_smies.smi" --drug_dataset_file="drugs.pt" --max_atom=45
python DrugGEN/main.py --submodel="CrossLoss" --mode="train" --raw_file="DrugGEN/data/chembl_train.smi" --dataset_file="chembl45_train.pt" --drug_raw_file="DrugGEN/data/akt_train.smi" --drug_dataset_file="drugs_train.pt" --max_atom=45
```

Please find the arguments in the **main.py** file. Explanations of the commands can be found below.
Expand All @@ -180,7 +180,8 @@ Model arguments:
--mlp_ratio MLP_RATIO MLP ratio for the Transformers
--dis_select DIS_SELECT Select the discriminator for the first and second GAN
--init_type INIT_TYPE Initialization type for the model
--dropout DROPOUT Dropout rate for the model
--dropout DROPOUT Dropout rate for the encoder
--dec_dropout DEC_DROPOUT Dropout rate for the decoder
Training arguments:
--batch_size BATCH_SIZE Batch size for the training
--epoch EPOCH Epoch number for Training
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main(config):
trainer = Trainer(config)

if config.mode == 'train':
trainer.train(config)
trainer.train()
elif config.mode == 'inference':
trainer.inference()

Expand Down
74 changes: 46 additions & 28 deletions new_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from torch_geometric.data import (Data, InMemoryDataset)
import os.path as osp
from tqdm import tqdm
import re
from rdkit import RDLogger
import re
from rdkit import RDLogger
import pandas as pd

RDLogger.DisableLog('rdApp.*')
class DruggenDataset(InMemoryDataset):

Expand Down Expand Up @@ -38,18 +40,47 @@ def processed_file_names(self):
def _generate_encoders_decoders(self, data):

self.data = data
print('Creating atoms encoder and decoder..')
atom_labels = sorted(set([atom.GetAtomicNum() for mol in self.data for atom in mol.GetAtoms()] + [0]))
print('Creating atoms and bonds encoder and decoder..')

atom_labels = set()
bond_labels = set()
max_length = 0
smiles_list = []
for smiles in tqdm(data):
mol = Chem.MolFromSmiles(smiles)
molecule_size = mol.GetNumAtoms()
if molecule_size > self.max_atom:
continue
smiles_list.append(smiles)
atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
max_length = max(max_length, molecule_size)
bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])

atom_labels.update([0]) # add PAD symbol (for unknown atoms)
atom_labels = sorted(atom_labels) # turn set into list and sort it

bond_labels = sorted(bond_labels)
bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels

# atom_labels = sorted(set([atom.GetAtomicNum() for mol in self.data for atom in mol.GetAtoms()] + [0]))
self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
self.atom_num_types = len(atom_labels)
print('Created atoms encoder and decoder with {} atom types and 1 PAD symbol!'.format(
self.atom_num_types - 1))
print("atom_labels", atom_labels)
print('Creating bonds encoder and decoder..')
bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType()
for mol in self.data
for bond in mol.GetBonds())))
# print('Creating bonds encoder and decoder..')
# bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType()
# for mol in self.data
# for bond in mol.GetBonds())))
# bond_labels = [
# Chem.rdchem.BondType.ZERO,
# Chem.rdchem.BondType.SINGLE,
# Chem.rdchem.BondType.DOUBLE,
# Chem.rdchem.BondType.TRIPLE,
# Chem.rdchem.BondType.AROMATIC,
# ]

print("bond labels", bond_labels)
self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
Expand All @@ -72,7 +103,7 @@ def _generate_encoders_decoders(self, data):
with open("DrugGEN/data/decoders/" +"bond_" + self.dataset_name + ".pkl","wb") as bond_decoders:
pickle.dump(self.bond_decoder_m,bond_decoders)


return max_length, smiles_list # data is filtered now

def _genA(self, mol, connected=True, max_length=None):

Expand Down Expand Up @@ -235,25 +266,14 @@ def label2onehot(self, labels, dim):
return out.float()

def process(self, size= None):

mols = [Chem.MolFromSmiles(line) for line in open(self.raw_files, 'r').readlines()]

mols = list(filter(lambda x: x.GetNumAtoms() <= self.max_atom, mols))
mols = mols[:size]
indices = range(len(mols))

self._generate_encoders_decoders(mols)



pbar = tqdm(total=len(indices))
pbar.set_description(f'Processing chembl dataset')
max_length = max(mol.GetNumAtoms() for mol in mols)
smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
max_length, smiles_list = self._generate_encoders_decoders(smiles_list)

data_list = []

self.m_dim = len(self.atom_decoder_m)
for idx in indices:
mol = mols[idx]
for smiles in tqdm(smiles_list, desc='Processing chembl dataset', total=len(smiles_list)):
mol = Chem.MolFromSmiles(smiles)
A = self._genA(mol, connected=True, max_length=max_length)
if A is not None:

Expand All @@ -270,7 +290,7 @@ def process(self, size= None):
edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

if self.pre_filter is not None and not self.pre_filter(data):
continue
Expand All @@ -279,9 +299,7 @@ def process(self, size= None):
data = self.pre_transform(data)

data_list.append(data)
pbar.update(1)

pbar.close()

torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))

Expand Down
100 changes: 74 additions & 26 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
from training_data import load_data
from training_data import generate_z_values, load_molecules
import random
from tqdm import tqdm

Expand Down Expand Up @@ -567,19 +567,42 @@ def train(self):

# Preprocess both dataset

bulk_data = load_data(data,
drugs,
self.batch_size,
self.device,
self.b_dim,
self.m_dim,
self.drugs_b_dim,
self.drugs_m_dim,
self.z_dim,
self.vertexes)
# bulk_data = load_data(data,
# drugs,
# self.batch_size,
# self.device,
# self.b_dim,
# self.m_dim,
# self.drugs_b_dim,
# self.drugs_m_dim,
# self.z_dim,
# self.vertexes)

drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
# drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data

z, z_edge, z_node = generate_z_values(
batch_size=self.batch_size,
z_dim=self.z_dim,
vertexes=self.vertexes,
device=self.device,
)

real_graphs, a_tensor, x_tensor = load_molecules(
data=data,
batch_size=self.batch_size,
device=self.device,
b_dim=self.b_dim,
m_dim=self.m_dim,
)

drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
data=drugs,
batch_size=self.batch_size,
device=self.device,
b_dim=self.drugs_b_dim,
m_dim=self.drugs_m_dim,
)

if self.submodel == "CrossLoss":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
Expand Down Expand Up @@ -700,11 +723,13 @@ def train(self):


if (i+1) % self.log_step == 0:

logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1,self.sample_directory)
if self.submodel == "CrossLoss":
logging(self.log_path, self.start_time, fake_mol, drug_smiles, i, idx, loss, 1, self.sample_directory)
else:
logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1, self.sample_directory)
mol_sample(self.sample_directory,"GAN1",fake_mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), idx, i)
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2,self.sample_directory)
logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2, self.sample_directory)
mol_sample(self.sample_directory,"GAN2",fake_mol_g, dr_g_edges_hat_sample.detach(), dr_g_nodes_hat_sample.detach(), idx, i)


Expand Down Expand Up @@ -795,19 +820,42 @@ def inference(self):

# Preprocess both dataset

bulk_data = load_data(data,
drugs,
self.inf_batch_size,
self.device,
self.b_dim,
self.m_dim,
self.drugs_b_dim,
self.drugs_m_dim,
self.z_dim,
self.vertexes)
# bulk_data = load_data(data,
# drugs,
# self.inf_batch_size,
# self.device,
# self.b_dim,
# self.m_dim,
# self.drugs_b_dim,
# self.drugs_m_dim,
# self.z_dim,
# self.vertexes)

drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
# drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data

z, z_edge, z_node = generate_z_values(
batch_size=self.batch_size,
z_dim=self.z_dim,
vertexes=self.vertexes,
device=self.device,
)

real_graphs, a_tensor, x_tensor = load_molecules(
data=data,
batch_size=self.batch_size,
device=self.device,
b_dim=self.b_dim,
m_dim=self.m_dim,
)

drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
data=drugs,
batch_size=self.batch_size,
device=self.device,
b_dim=self.drugs_b_dim,
m_dim=self.drugs_m_dim,
)

if self.submodel == "CrossLoss":
GAN1_input_e = a_tensor
GAN1_input_x = x_tensor
Expand Down
Loading

0 comments on commit 6e4a0d1

Please sign in to comment.