-
Notifications
You must be signed in to change notification settings - Fork 1
/
ApplyTempl.py
216 lines (205 loc) · 8.44 KB
/
ApplyTempl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# %load ./ApplyTempl.py
import copy
import itertools
import pandas as pd
from GenTempl import clear_isotope
from MainFunctions import initray, molfromsmiles
import modin.pandas as mpd
from rdkit import Chem
from rdkit.Chem import rdChemReactions
def applytemplate(analoguerxnstemplfilt, inputquery, ncpus=16, restart=True):
    """Apply each row's reaction template to the query and attach the outcome.

    Every row of *analoguerxnstemplfilt* is passed to ``apply_template_row``;
    the four values it returns are stored in the columns ``querycompds``,
    ``impurities``, ``impurityrxn`` and ``msg5`` of a deep copy of the input.

    Parameters
    ----------
    analoguerxnstemplfilt : pandas.DataFrame
        Rows must provide the attributes read by ``apply_template_row``
        (``LHSdata``, ``template``, ``unusedanalogue``).
    inputquery : dict
        Query description forwarded unchanged to ``apply_template_row``.
    ncpus : int
        When greater than 1 the frame is wrapped in a modin DataFrame so the
        row-wise apply can run in parallel.
    restart : bool
        With ``ncpus > 1``, (re)initialise Ray via ``initray`` first.

    Returns
    -------
    pandas.DataFrame
        A deep copy of the input with the four result columns set; the input
        frame itself is left untouched.
    """
    outcols = ["querycompds", "impurities", "impurityrxn", "msg5"]
    parallel = ncpus > 1
    if parallel and restart:
        initray(num_cpus=ncpus)
    source = mpd.DataFrame(analoguerxnstemplfilt) if parallel else analoguerxnstemplfilt
    rowresults = source.apply(
        apply_template_row, inputquery=inputquery, axis=1, result_type="reduce"
    )
    # Rebuild as a plain pandas Series so a modin result is converted back to pandas.
    rowresults = pd.Series(data=rowresults.values, index=rowresults.index)
    resultdf = pd.DataFrame(
        data=rowresults.tolist(), index=rowresults.index, columns=outcols
    )
    analoguerxnsimp = copy.deepcopy(analoguerxnstemplfilt)
    analoguerxnsimp[outcols] = resultdf
    return analoguerxnsimp
def apply_template_row(row, inputquery):
    """Unpack one DataFrame row and run ``apply_template`` on it.

    ``row.LHSdata`` and ``row.template`` are deep-copied so that
    ``apply_template`` cannot alter the frame's stored objects;
    ``row.unusedanalogue`` is passed through as-is.

    Returns the 4-tuple produced by ``apply_template``.
    """
    return apply_template(
        copy.deepcopy(row.LHSdata),
        copy.deepcopy(row.template),
        inputquery,
        unusedanalogue=row.unusedanalogue,
    )
def _append_msg(msg, addition):
    """Return *msg* with *addition* appended, joined by ", " when *msg* is non-empty."""
    return msg + ", " + addition if msg else addition


def apply_template(LHSdata, templt, inputquery, unusedanalogue=None):
    """Apply a reaction template to every combination of query compounds.

    For each analogue reactant in *LHSdata* a "bin" of candidate reactants is
    built; the Cartesian product of all bins gives the reactant combinations
    the template is run on.

    Parameters
    ----------
    LHSdata : dict
        Keyed by analogue compound. Each value is read for ``"smiles"``,
        ``"count"`` and, when the analogue reacts, ``"reacfrag"``
        (instance -> {fragment: matched indices}) and ``"querycompds"``
        (fragment -> query compounds matching that fragment).
    templt : str
        Reaction SMARTS of the template.
    inputquery : dict
        Read as ``inputquery["species"][querycompd][frag]["count"]``.
    unusedanalogue : sequence, optional
        Analogues without reacting fragments that are left out of the
        combinations instead of being passed through as their own SMILES.
        ``None`` (the default) means none are excluded.

    Returns
    -------
    tuple
        ``(combs, impurities, impurityrxns, msg)``. On success: the
        combinations that gave chemically valid products, a list of sets of
        product-SMILES tuples, a list of sets of reaction SMILES, and
        ``"Valid"`` (or the accumulated warning text). On failure one or more
        entries are ``[]`` or ``"Error"`` and *msg* states the reason.
    """
    if unusedanalogue is None:
        unusedanalogue = []
    msg5 = ""
    template_rxn = rdChemReactions.ReactionFromSmarts(templt, useSmiles=False)
    diffreac = set()  # analogues whose fragments match different query compounds
    querycompdbin = []  # one set of candidate reactant SMILES per reactant slot
    for analoguecompd in LHSdata:
        LHSdata_ = LHSdata[analoguecompd]
        # Hydrogen will not be mapped and won't show up in reaction center
        if "reacfrag" not in LHSdata_ or not LHSdata_["reacfrag"]:
            # Non-reacting analogue: pass its own SMILES through, once per count.
            if analoguecompd not in unusedanalogue:
                querycompdbin += [
                    {LHSdata_["smiles"]} for _ in range(LHSdata_["count"])
                ]
            continue
        for fraginf in LHSdata_["reacfrag"].values():
            querycompdset = set()
            for frag, matchidx in fraginf.items():
                fragquerycompds = LHSdata_["querycompds"][frag]
                for querycompd in fragquerycompds:
                    if inputquery["species"][querycompd][frag]["count"] < len(matchidx):
                        # Invalid templates as one/more reactants have more
                        # functional groups reacting relative to query
                        return (
                            [],
                            [],
                            [],
                            "Species "
                            + str(analoguecompd)
                            + " has too many reacting functional groups",
                        )
                if not querycompdset:
                    querycompdset = set(fragquerycompds)
                    continue
                # There may be more than one query compound with a match:
                # prefer compounds shared by all fragments, otherwise keep all.
                common = querycompdset.intersection(fragquerycompds)
                if common:
                    querycompdset = common
                else:
                    diffreac.add(analoguecompd)
                    querycompdset = querycompdset.union(fragquerycompds)
            querycompdbin.append(querycompdset)
    if diffreac:
        msg5 = (
            "Species reacts at fragments from different query compounds: "
            + ",".join([str(rct) for rct in diffreac])
        )
    # The product grows multiplicatively with the number of bins; the original
    # author notes it becomes problematic beyond about 6.
    combs = list(itertools.product(*querycompdbin))
    imp_raw = []
    try:
        # Addition of Chem.AddHs significantly impacts results, so explicit
        # hydrogens are removed before running the template.
        for comb in combs:
            querymols = []
            for query_reactant in comb:
                querymol = molfromsmiles(query_reactant)
                clear_isotope(querymol)
                querymols.append(Chem.RemoveHs(querymol))
            imp_raw.append(template_rxn.RunReactants(querymols))
    except Exception:
        # Deliberate best-effort: any RDKit failure is reported, not raised.
        return combs, "Error", "Error", _append_msg(
            msg5, "Template cannot be applied to reactants"
        )
    # Fail only if no combination produced anything; an empty bin yields no
    # combinations at all, which must not raise IndexError either.
    if not any(imp_raw):
        return combs, "Error", "Error", _append_msg(
            msg5, "Template cannot be applied to reactants"
        )
    # Per combination: the set of unique product-SMILES tuples.
    imp_smles = [
        {tuple(Chem.MolToSmiles(imp) for imp in imp_prod) for imp_prod in prodsets}
        for prodsets in imp_raw
    ]
    # Keep only chemically valid products, and only combinations that have
    # at least one such product.
    combs2 = []
    imp_smles2 = []
    for comb, prodset in zip(combs, imp_smles):
        validprods = set()
        for imp_prod in prodset:
            try:
                # NOTE(review): relies on molfromsmiles raising for invalid
                # SMILES; if it returns None instead, nothing is filtered.
                for imp in imp_prod:
                    molfromsmiles(imp)
            except Exception:
                continue
            validprods.add(imp_prod)
        if validprods:
            imp_smles2.append(validprods)
            combs2.append(comb)
    if not imp_smles2:
        return combs, imp_smles, "Error", _append_msg(
            msg5, "Impurities not chemically valid"
        )
    imp_Rsmiles = [
        {">>".join([".".join(comb), ".".join(impprod)]) for impprod in prodset}
        for comb, prodset in zip(combs2, imp_smles2)
    ]
    return combs2, imp_smles2, imp_Rsmiles, msg5 or "Valid"
def removeduplicates(analoguerxnsimpfinal, ncpus=16, restart=False):
    """Run ``removeduplicates_`` on every row and write the results back.

    The columns ``querycompds``, ``impurities`` and ``impurityrxn`` of
    *analoguerxnsimpfinal* are overwritten IN PLACE, and the same (mutated)
    frame is returned.

    Parameters
    ----------
    analoguerxnsimpfinal : pandas.DataFrame
        Frame holding the three columns above.
    ncpus : int
        When greater than 1 the row-wise apply runs on a modin DataFrame.
    restart : bool
        With ``ncpus > 1``, (re)initialise Ray via ``initray`` first.
    """
    outcols = ["querycompds", "impurities", "impurityrxn"]
    if ncpus > 1:
        if restart:
            initray(num_cpus=ncpus)
        frame = mpd.DataFrame(analoguerxnsimpfinal)
    else:
        frame = analoguerxnsimpfinal
    deduped = frame.apply(removeduplicates_, axis=1, result_type="reduce")
    # Rebuild as a plain pandas Series so a modin result is converted back to pandas.
    deduped = pd.Series(data=deduped.values, index=deduped.index)
    analoguerxnsimpfinal[outcols] = pd.DataFrame(
        data=deduped.tolist(), index=deduped.index, columns=outcols
    )
    return analoguerxnsimpfinal
def removeduplicates_(row):
    """Drop permuted duplicate combinations whose impurities are already covered.

    A combination in ``row.querycompds`` is a duplicate when its sorted form
    equals that of an earlier combination. A duplicate is removed only if
    every one of its impurities (compared as sorted tuples) is also produced
    by another combination that is still kept.

    Parameters
    ----------
    row : row-like object
        Must expose ``querycompds``, ``impurities`` and ``impurityrxn``,
        aligned index-for-index.

    Returns
    -------
    tuple
        ``(querycompds, impurities, impurityrxns)`` with the removed
        positions filtered out of all three. The values are deep copies, so
        the row itself is not modified. If ``impurities`` or ``impurityrxn``
        is a string (such as an ``"Error"`` marker), the copies are returned
        unfiltered.
    """
    querycompds = copy.deepcopy(row.querycompds)
    impurities = copy.deepcopy(row.impurities)
    impurityrxns = copy.deepcopy(row.impurityrxn)
    # Error markers are plain strings rather than per-combination sequences;
    # indexing or filtering them would split them into characters.
    if isinstance(impurities, str) or isinstance(impurityrxns, str):
        return querycompds, impurities, impurityrxns
    # The first occurrence of each order-insensitive combination is kept;
    # later occurrences become removal candidates.
    seen = set()
    rejidx = []
    for idx, comb in enumerate(querycompds):
        norm = tuple(sorted(comb))
        if norm in seen:
            rejidx.append(idx)
        else:
            seen.add(norm)
    if not rejidx:  # No duplicates
        return querycompds, impurities, impurityrxns
    removed = set()
    for idx in rejidx:
        compareset = {tuple(sorted(t)) for t in impurities[idx]}
        # Compare only against entries that are still kept, so two duplicates
        # that cover each other are not both dropped (losing their impurities).
        otherset = {
            tuple(sorted(t))
            for oidx, imps in enumerate(impurities)
            if oidx != idx and oidx not in removed
            for t in imps
        }
        # Duplicate entry is present in other query compound combinations
        if compareset.issubset(otherset):
            removed.add(idx)
    querycompds = [comb for idx, comb in enumerate(querycompds) if idx not in removed]
    impurities = [comb for idx, comb in enumerate(impurities) if idx not in removed]
    impurityrxns = [comb for idx, comb in enumerate(impurityrxns) if idx not in removed]
    return querycompds, impurities, impurityrxns