diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index c1ea20a3..43c34eb2 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -231,6 +231,10 @@ def ensure_directed_line_graph_compatibility( if line_graph.number_of_nodes() > sum(valid_edges): valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol + # check again and raise if invalid + if line_graph.number_of_nodes() > sum(valid_edges): + raise RuntimeError("Line graph is not compatible with graph.") + edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] line_graph.ndata["edge_ids"] = edge_ids diff --git a/tests/graph/test_compute.py b/tests/graph/test_compute.py index d175a868..71cb3fae 100644 --- a/tests/graph/test_compute.py +++ b/tests/graph/test_compute.py @@ -260,5 +260,5 @@ def test_ensure_directed_line_graph_compat(graph_data, request): tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids) tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) - with pytest.raises(AssertionError): + with pytest.raises(RuntimeError): ensure_directed_line_graph_compatibility(g, line_graph, 1.0)