Skip to content

Commit

Permalink
Fix v-site - v-site exclusions missed (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Feb 21, 2024
1 parent 7ccc182 commit 88d1509
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 27 deletions.
2 changes: 1 addition & 1 deletion smee/tests/convertors/openff/test_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_convert_electrostatics_v_site():
assert parameter_map.exclusions.shape == (n_expected_exclusions, 2)
assert parameter_map.exclusion_scale_idxs.shape == (n_expected_exclusions, 1)

expected_exclusions = torch.tensor([[0, 1], [2, 1], [0, 2]], dtype=torch.long)
expected_exclusions = torch.tensor([[0, 1], [0, 2], [1, 2]], dtype=torch.long)
assert torch.allclose(parameter_map.exclusions, expected_exclusions)

expected_scales = torch.zeros((n_expected_exclusions, 1), dtype=torch.long)
Expand Down
11 changes: 6 additions & 5 deletions smee/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,18 @@ def test_find_exclusions_v_sites():
(0, 1): "scale_12",
(0, 2): "scale_13",
(0, 3): "scale_14",
(0, 4): "scale_12",
(0, 5): "scale_14",
(1, 2): "scale_12",
(1, 3): "scale_13",
(1, 4): "scale_12",
(1, 5): "scale_13",
(2, 3): "scale_12",
(2, 4): "scale_13",
(2, 5): "scale_12",
(4, 0): "scale_12",
(4, 1): "scale_12",
(4, 2): "scale_13",
(4, 3): "scale_14",
(5, 3): "scale_12",
(3, 4): "scale_14",
(3, 5): "scale_12",
(4, 5): "scale_14",
}


Expand Down
33 changes: 12 additions & 21 deletions smee/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def find_exclusions(
for bond in topology.bonds
)

if v_sites is not None:

for v_site_key in v_sites.keys:
v_site_idx = v_sites.key_to_idx[v_site_key]
parent_idx = v_site_key.orientation_atom_indices[0]

for neighbour_idx in graph.neighbors(parent_idx):
graph.add_edge(v_site_idx, neighbour_idx)

graph.add_edge(v_site_idx, parent_idx)

distances = dict(networkx.all_pairs_shortest_path_length(graph, cutoff=5))
distance_to_scale = {1: "scale_12", 2: "scale_13", 3: "scale_14", 4: "scale_15"}

Expand All @@ -56,27 +67,7 @@ def find_exclusions(
assert pair not in exclusions or exclusions[pair] == scale
exclusions[pair] = scale

if v_sites is None:
return exclusions

v_site_exclusions = {}

for v_site_key in v_sites.keys:
v_site_idx = v_sites.key_to_idx[v_site_key]
parent_idx = v_site_key.orientation_atom_indices[0]

v_site_exclusions[(v_site_idx, parent_idx)] = "scale_12"

for pair, scale in exclusions.items():
if parent_idx not in pair:
continue

if pair[0] == parent_idx:
v_site_exclusions[(v_site_idx, pair[1])] = scale
else:
v_site_exclusions[(pair[0], v_site_idx)] = scale

return {**exclusions, **v_site_exclusions}
return exclusions


def ones_like(size: _size, other: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit 88d1509

Please sign in to comment.