Skip to content

Commit

Permalink
Fix hydrogen handling in scaffolds with explicit attachment points (#70)
Browse files Browse the repository at this point in the history
When a scaffold is provided during decoding, we assume the model is free
to attach further atoms to any atom in the scaffold, unless it contains
explicit attachment points, in which case we only allow attaching to
those atoms. However, the handling of these attachment points had a bug
regarding explicit hydrogens, which would cause the scaffold molecule to
be malformed and decoding to fail. This PR fixes this issue by carefully
updating the hydrogen count where appropriate.
  • Loading branch information
kmaziarz committed Dec 18, 2023
1 parent 2aaf7de commit e7e9e38
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
- Relax `protobuf` version requirement ([#62](https://github.com/microsoft/molecule-generation/pull/62))

### Fixed
- Fix hydrogen handling in scaffolds with explicit attachment points ([#70](https://github.com/microsoft/molecule-generation/pull/70))
- Avoid memory leaks and other `tensorflow` issues ([#68](https://github.com/microsoft/molecule-generation/pull/68))

## [0.4.0] - 2023-06-16
Expand Down
23 changes: 17 additions & 6 deletions molecule_generation/layers/moler_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,8 +1193,8 @@ def decode(
for graph_repr, init_mol, mol_id in zip(graph_representations, initial_molecules, mol_ids):
num_free_bond_slots = [0] * len(init_mol.GetAtoms())

atom_ids_to_remove = []
atom_ids_to_keep = []
atom_id_pairs_to_disconnect: List[Tuple[int, int]] = []
atom_ids_to_keep: List[int] = []

for atom in init_mol.GetAtoms():
if atom.GetAtomicNum() == 0:
Expand All @@ -1220,22 +1220,29 @@ def decode(
neighbour_idx = begin_idx if begin_idx != atom.GetIdx() else end_idx
num_free_bond_slots[neighbour_idx] += 1

atom_ids_to_remove.append(atom.GetIdx())
atom_id_pairs_to_disconnect.append((atom.GetIdx(), neighbour_idx))
else:
atom_ids_to_keep.append(atom.GetIdx())

if not atom_ids_to_remove:
init_mol_original = init_mol
if not atom_id_pairs_to_disconnect:
# No explicit attachment points, so assume we can connect anywhere.
num_free_bond_slots = None
else:
num_free_bond_slots = [num_free_bond_slots[idx] for idx in atom_ids_to_keep]
init_mol = Chem.RWMol(init_mol)

# Save the atom list to be able to extract neighbour atoms by their original id.
original_atom_list = list(init_mol.GetAtoms())

# Remove atoms starting from largest index, so that we don't have to account for
# indices shifting during removal.
for atom_idx in reversed(atom_ids_to_remove):
# indices of atoms to remove shifting due to other removals.
for atom_idx, neighbour_idx in reversed(atom_id_pairs_to_disconnect):
init_mol.RemoveAtom(atom_idx)

neighbour_atom = original_atom_list[neighbour_idx]
neighbour_atom.SetNumExplicitHs(neighbour_atom.GetNumExplicitHs() + 1)

# Determine how the scaffold atoms will get reordered when we canonicalize it, so we can
# permute `num_free_bond_slots` appropriately.
canonical_ordering = compute_canonical_atom_order(init_mol)
Expand All @@ -1245,6 +1252,10 @@ def decode(
# renumbering to `num_free_bond_slots` earlier.
init_mol = Chem.MolFromSmiles(Chem.MolToSmiles(init_mol))

if init_mol is None:
scaffold = Chem.MolToSmiles(init_mol_original)
raise ValueError(f"Scaffold {scaffold} could not be processed")

# Clear aromatic flags in the scaffold, since partial graphs during training never have
# them set (however we _do_ run `AtomIsAromaticFeatureExtractor`, it just always returns
# 0 for partial graphs during training).
Expand Down

0 comments on commit e7e9e38

Please sign in to comment.