Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive decompose routine for fragment env #96

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion src/gflownet/envs/frag_mol_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
GraphActionType.RemoveEdgeAttr,
]
self.device = torch.device("cpu")
self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms())

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction
Expand Down Expand Up @@ -275,7 +276,11 @@ def collate(self, graphs: List[gd.Data]) -> gd.Batch:

def mol_to_graph(self, mol):
    """Convert an RDMol to a Graph

    Every fragment in ``self.sorted_frags`` is matched against ``mol`` once
    up front, and the resulting match table is handed to
    ``_recursive_decompose``, which searches for a set of non-overlapping
    fragment placements that rebuilds the molecule.

    Parameters
    ----------
    mol: Chem.Mol
        The molecule to decompose.

    Returns
    -------
    g: Graph or None
        Whatever ``_recursive_decompose`` returns: a fragment graph on
        success, ``None`` when no decomposition was found.

    Raises
    ------
    ValueError
        Propagated from ``_recursive_decompose`` when its call budget is
        exhausted.
    """
    # isinstance (rather than an exact type comparison) so Mol subclasses are accepted too.
    assert isinstance(mol, Chem.Mol), f"expected a Chem.Mol, got {type(mol).__name__}"
    # uniquify=False keeps every atom ordering of a match; symmetric fragments
    # need all of them because stems are tied to specific atom positions.
    all_matches = {
        fragidx: mol.GetSubstructMatches(frag, uniquify=False)
        for fragidx, frag in self.sorted_frags
    }
    # Depth budget of 9 equals the default of `max_frags` in __init__'s signature.
    # NOTE(review): consider passing the configured fragment limit instead -- confirm it is stored on self.
    return _recursive_decompose(self, mol, all_matches, {}, [], [], 9)

def graph_to_mol(self, g: Graph) -> Chem.Mol:
"""Convert a Graph to an RDKit molecule
Expand Down Expand Up @@ -331,3 +336,101 @@ def is_sane(self, g: Graph) -> bool:
if mol is None:
return False
return True


def _recursive_decompose(ctx, m, all_matches, a2f, frags, bonds, max_depth=9, numiters=None):
    """Depth-first search for a fragment decomposition of the molecule ``m``.

    Each call tries to place one more fragment on atoms of ``m`` that are not
    yet covered, records the single bonds that connect it to fragments placed
    earlier, and recurses. Once every atom is covered (or the depth budget is
    spent) the candidate is turned into a graph and validated by rebuilding a
    molecule from it.

    Parameters
    ----------
    ctx:
        Object providing ``sorted_frags`` (iterable of ``(fragidx, frag)``),
        ``frags_stems`` (indexable by ``fragidx``, giving the stem atom
        indices of that fragment) and ``graph_to_mol``.
    m:
        The molecule being decomposed.
    all_matches: dict
        ``fragidx -> matches``; each match is a sequence of atom indices of
        ``m``, ordered by the fragment's own atom index.
    a2f: dict
        Atoms covered so far:
        ``atom index in m -> (atom index within fragment, position in frags)``.
    frags: list
        ``fragidx`` of every fragment placed so far.
    bonds: list
        Inter-fragment bonds found so far, as tuples
        ``(older_frag, newer_frag, older_stem, newer_stem, newer_atom, older_atom)``.
    max_depth: int
        Remaining number of fragments that may still be placed.
    numiters: list or None
        One-element list used as a call counter shared by the whole search;
        created on the outermost call when omitted.

    Returns
    -------
    g: nx.Graph or None
        The fragment graph (node attribute ``"v"``, edge attributes
        ``"<node>_attach"``) or ``None`` if this branch has no solution.

    Raises
    ------
    ValueError
        If more than 1000 calls were made within a single search.
    """
    if numiters is None:
        numiters = [0]
    numiters[0] += 1
    if numiters[0] > 1_000:
        raise ValueError("too many iterations")

    if max_depth == 0 or len(a2f) == m.GetNumAtoms():
        return _assemble_fragment_graph(ctx, m, a2f, frags, bonds)

    new_frag_idx = len(frags)
    for fragidx, frag in ctx.sorted_frags:
        # Some fragments have symmetric versions, so we need all matches up to isomorphism!
        for match in all_matches[fragidx]:
            # Skip placements that overlap atoms already claimed by another fragment
            if any(i in a2f for i in match):
                continue
            # Verify that atoms actually have the same charge
            if any(
                frag.GetAtomWithIdx(ai).GetFormalCharge() != m.GetAtomWithIdx(bi).GetFormalCharge()
                for ai, bi in enumerate(match)
            ):
                continue
            if not _match_respects_stems(m, match, ctx.frags_stems[fragidx]):
                continue
            possible_bonds = _bonds_to_placed_fragments(ctx, m, match, fragidx, a2f, frags, bonds)
            new_a2f = {**a2f, **{i: (fi, new_frag_idx) for fi, i in enumerate(match)}}
            dec = _recursive_decompose(
                ctx, m, all_matches, new_a2f, frags + [fragidx], bonds + possible_bonds, max_depth - 1, numiters
            )
            # Compare against None explicitly: a graph's truthiness is its node count.
            if dec is not None:
                return dec
    return None


def _assemble_fragment_graph(ctx, m, a2f, frags, bonds):
    """Build and validate the fragment graph for a finished candidate; None if it is rejected."""
    # Did we match all the atoms?
    if len(a2f) < m.GetNumAtoms():
        return None
    # graph is a tree, e = n - 1
    if len(bonds) != len(frags) - 1:
        return None
    g = nx.Graph()
    g.add_nodes_from(range(len(frags)))
    g.add_edges_from([(b[0], b[1]) for b in bonds])
    assert nx.is_connected(g), "Somehow we got here but fragments dont connect?"
    for fi, f in enumerate(frags):
        g.nodes[fi]["v"] = f
    for a, b, stemidx_a, stemidx_b, _, _ in bonds:
        g.edges[(a, b)][f"{a}_attach"] = stemidx_a
        g.edges[(a, b)][f"{b}_attach"] = stemidx_b
    # try to make a mol, does it work? Accept only if each molecule is a substructure of the other.
    m2 = ctx.graph_to_mol(g)
    if m2.HasSubstructMatch(m) and m.HasSubstructMatch(m2):
        return g
    return None


def _match_respects_stems(m, match, stems):
    """Return True if every bond leaving ``match`` is a single bond that starts at a stem atom."""
    match_atoms = set(match)
    for fi, i in enumerate(match):
        for nbr in m.GetAtomWithIdx(i).GetNeighbors():
            j = nbr.GetIdx()
            if j in match_atoms:
                continue
            # There should only be single bonds between fragments
            if m.GetBondBetweenAtoms(i, j).GetBondType() != Chem.BondType.SINGLE:
                return False
            # (i, j) is a single bond that goes outside the fragment, so the
            # fragment-local atom index must be one of the fragment's stems
            if fi not in stems:
                return False
    return True


def _bonds_to_placed_fragments(ctx, m, match, fragidx, a2f, frags, bonds):
    """Collect bonds linking the fragment placed at ``match`` to fragments already in ``frags``."""
    new_frag_idx = len(frags)
    match_atoms = set(match)
    possible_bonds = []
    for this_frag_stemidx, stem_atom in enumerate(ctx.frags_stems[fragidx]):
        i = match[stem_atom]
        for nbr in m.GetAtomWithIdx(i).GetNeighbors():
            j = nbr.GetIdx()
            if j in match_atoms:
                continue
            if m.GetBondBetweenAtoms(i, j).GetBondType() != Chem.BondType.SINGLE:
                continue
            # Make sure the neighbor is part of an already identified fragment
            if j not in a2f:
                continue
            other_frag_atomidx, other_frag_idx = a2f[j]
            other_stems = ctx.frags_stems[frags[other_frag_idx]]
            try:
                # Make sure that fragment has that atom as a stem atom
                other_frag_stemidx = other_stems.index(other_frag_atomidx)
            except ValueError:
                continue
            # Make sure neither stem is already used by an existing bond
            endpoints = ((other_frag_idx, other_frag_stemidx), (new_frag_idx, this_frag_stemidx))
            stem_taken = any(
                (b[0], b[2]) in endpoints or (b[1], b[3]) in endpoints for b in bonds + possible_bonds
            )
            if not stem_taken:
                possible_bonds.append(
                    (other_frag_idx, new_frag_idx, other_frag_stemidx, this_frag_stemidx, i, j)
                )
    return possible_bonds