Skip to content

Commit

Permalink
raise runtime error for incompatible graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Oct 2, 2023
1 parent 4001318 commit e32faa7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def ensure_directed_line_graph_compatibility(
if line_graph.number_of_nodes() > sum(valid_edges):
valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol

# check again and raise if invalid
if line_graph.number_of_nodes() > sum(valid_edges):
raise RuntimeError("Line graph is not compatible with graph.")

edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()]
line_graph.ndata["edge_ids"] = edge_ids

Expand Down
2 changes: 1 addition & 1 deletion tests/graph/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,5 +260,5 @@ def test_ensure_directed_line_graph_compat(graph_data, request):
tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids)
tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign)

with pytest.raises(AssertionError):
with pytest.raises(RuntimeError):
ensure_directed_line_graph_compatibility(g, line_graph, 1.0)

0 comments on commit e32faa7

Please sign in to comment.