-
Notifications
You must be signed in to change notification settings - Fork 21
/
solubility_comparison.py
executable file
·131 lines (100 loc) · 4.5 KB
/
solubility_comparison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python
import sys
import os
import deepchem
from deepchem.models import GraphConvModel, WeaveModel
from deepchem.models.sklearn_models import RandomForestRegressor, SklearnModel
import pandas as pd
from rdkit import Chem
import itertools
from esol import ESOLCalculator
# Turn off TensorFlow logging
# NOTE(review): this assignment runs after the deepchem imports above. If those
# imports load TensorFlow eagerly, the variable may be set too late to silence
# TensorFlow's start-up messages -- confirm, and consider moving it above the imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# -------- Utility Functions -----------------------------------#
def featurize_data(tasks, featurizer, normalize, dataset_file):
    """Featurize a CSV file and optionally apply a y-normalization transformer.

    The file is read with ``deepchem.data.CSVLoader`` using the "SMILES" column
    as the structure field and a shard size of 8192.

    Args:
        tasks: Task (column) names passed to the loader.
        featurizer: Featurizer object handed to the loader.
        normalize: When truthy, a ``NormalizationTransformer`` built from the
            featurized dataset (``transform_y=True``, ``move_mean=True``) is
            applied to it.
        dataset_file: Path of the CSV file to featurize.

    Returns:
        A ``(dataset, featurizer, transformers)`` tuple, where ``transformers``
        is the list of transformers applied (empty when ``normalize`` is falsy).
    """
    csv_loader = deepchem.data.CSVLoader(
        tasks=tasks, smiles_field="SMILES", featurizer=featurizer)
    raw_dataset = csv_loader.featurize(dataset_file, shard_size=8192)

    # Nothing to transform: hand the featurized data back untouched.
    if not normalize:
        return raw_dataset, featurizer, []

    normalizer = deepchem.trans.NormalizationTransformer(
        transform_y=True, dataset=raw_dataset, move_mean=True)
    return normalizer.transform(raw_dataset), featurizer, [normalizer]
def _flatten_predictions(res):
    """Collapse nested list output from a model into one flat list of values.

    A one-element ``res`` is treated as a wrapper around the per-molecule rows
    (the only nested shape this script historically accepted); otherwise
    ``res`` itself is taken to be the rows. Each row is expanded in place, and
    rows that are not iterable are kept as single values.
    """
    rows = res[0] if len(res) == 1 else res
    flat = []
    for row in rows:
        try:
            flat.extend(row)
        except TypeError:
            # The row is already a scalar prediction rather than a sequence.
            flat.append(row)
    return flat


def generate_prediction(input_file_name, model, featurizer, transformers):
    """Predict a value for every SMILES in a CSV file.

    Args:
        input_file_name: Path of a CSV file that has a "SMILES" column.
        model: Object exposing ``predict_on_batch(features, transformers)``.
        featurizer: Object exposing ``featurize(mol_list)``.
        transformers: Transformers forwarded to ``predict_on_batch``.

    Returns:
        The CSV contents as a DataFrame with an added "pred_vals" column.
        Assumes the model yields one prediction per input row -- TODO confirm
        for multi-task models.
    """
    df = pd.read_csv(input_file_name)
    mol_list = [Chem.MolFromSmiles(x) for x in df.SMILES]
    val_feats = featurizer.featurize(mol_list)
    res = model.predict_on_batch(val_feats, transformers)
    # Some models return nested lists while others (e.g. RF) return a flat
    # sequence. Only nested *lists* are flattened; any other return type is
    # assigned as-is. The length check avoids indexing into an empty result.
    if len(res) > 0 and isinstance(res[0], list):
        # Unpacking with `*res` into chain.from_iterable fails with TypeError
        # as soon as there is more than one sub-list, so flatten explicitly.
        df["pred_vals"] = _flatten_predictions(res)
    else:
        df["pred_vals"] = res
    return df
# ----------- Model Generator Functions --------------------------#
def generate_graph_conv_model():
    """Return a new regression-mode ``GraphConvModel`` with a batch size of 128.

    The leading positional argument ``1`` is presumably the number of tasks --
    verify against the deepchem version in use.
    """
    return GraphConvModel(1, batch_size=128, mode='regression')
def generate_weave_model():
    """Return a new regression-mode ``WeaveModel``.

    Configured with a batch size of 64, a learning rate of 1e-3 and
    ``use_queue=False``. The leading positional argument ``1`` is presumably
    the number of tasks -- verify against the deepchem version in use.
    """
    return WeaveModel(
        1,
        batch_size=64,
        learning_rate=1e-3,
        use_queue=False,
        mode='regression',
    )
def generate_rf_model():
    """Return a ``SklearnModel`` wrapping a 500-tree random forest regressor.

    The current directory (".") is passed as the model directory.
    """
    forest = RandomForestRegressor(n_estimators=500)
    return SklearnModel(forest, ".")
# ---------------- Function to Run Models ----------------------#
def run_model(model_func, task_list, featurizer, normalize, training_file_name, validation_file_name, nb_epoch):
    """Train a freshly built model and predict on the validation file.

    Args:
        model_func: Zero-argument callable that returns the model to train.
        task_list: Task names forwarded to ``featurize_data``.
        featurizer: Featurizer used for both training and validation data.
        normalize: Forwarded to ``featurize_data`` to toggle normalization.
        training_file_name: CSV file used to fit the model.
        validation_file_name: CSV file to generate predictions for.
        nb_epoch: Passed to ``model.fit`` as the second positional argument
            when positive; otherwise ``fit`` is called with the dataset only.

    Returns:
        The DataFrame produced by ``generate_prediction``.
    """
    train_set, featurizer, transformers = featurize_data(
        task_list, featurizer, normalize, training_file_name)
    model = model_func()

    # Only forward an epoch count when a positive one was requested.
    fit_args = (train_set, nb_epoch) if nb_epoch > 0 else (train_set,)
    model.fit(*fit_args)

    return generate_prediction(validation_file_name, model, featurizer, transformers)
# ------------------ Function to Calculate ESOL ----------------------*
def calc_esol(input_file_name, smiles_col="SMILES"):
    """Add an ESOL estimate for every SMILES string in a CSV file.

    Args:
        input_file_name: Path of the CSV file to read.
        smiles_col: Name of the column holding the SMILES strings.

    Returns:
        The CSV contents as a DataFrame with an added "pred_vals" column
        holding the result of ``ESOLCalculator.calc_esol`` for each row.
    """
    frame = pd.read_csv(input_file_name)
    calculator = ESOLCalculator()
    frame["pred_vals"] = [
        calculator.calc_esol(Chem.MolFromSmiles(smiles))
        for smiles in frame[smiles_col].values
    ]
    return frame
# ----------------- main ---------------------------------------------*
def main():
    """Compare ESOL, random forest, Weave and graph-convolution predictions.

    Each learned model is fitted on "delaney.csv" and then used to predict on
    "dls_100_unique.csv"; ESOL is computed directly on the validation file.
    The predictions are written next to the SMILES, chemical name and
    experimental LogS columns in "solubility_comparison.csv".
    """
    training_file_name = "delaney.csv"
    validation_file_name = "dls_100_unique.csv"
    output_file_name = "solubility_comparison.csv"
    task_list = ['measured log(solubility:mol/L)']

    def train_and_predict(model_func, featurizer, normalize, nb_epoch):
        # Every learned model shares the same task list and input files.
        return run_model(model_func, task_list, featurizer, normalize,
                         training_file_name, validation_file_name, nb_epoch=nb_epoch)

    print("=====ESOL=====")
    esol_df = calc_esol(validation_file_name)

    print("=====Random Forest=====")
    fingerprint_featurizer = deepchem.feat.fingerprints.CircularFingerprint(size=1024)
    rf_df = train_and_predict(generate_rf_model, fingerprint_featurizer, False, -1)

    print("=====Weave======")
    weave_featurizer = deepchem.feat.WeaveFeaturizer()
    weave_df = train_and_predict(generate_weave_model, weave_featurizer, True, 30)

    print("=====Graph Convolution=====")
    conv_featurizer = deepchem.feat.ConvMolFeaturizer()
    gc_df = train_and_predict(generate_graph_conv_model, conv_featurizer, True, 20)

    # Start from the descriptive columns, then append one column per method.
    output_df = pd.DataFrame(rf_df[["SMILES", "Chemical name", "LogS exp (mol/L)"]])
    predictions = (("ESOL", esol_df), ("RF", rf_df), ("Weave", weave_df), ("GC", gc_df))
    for column_name, prediction_df in predictions:
        output_df[column_name] = prediction_df["pred_vals"]

    output_df.to_csv(output_file_name, index=False, float_format="%0.2f")
    print("wrote results to", output_file_name)
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()