Skip to content

Commit

Permalink
add tolerance in lg compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Oct 2, 2023
1 parent 0f97de3 commit 4001318
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) ->


def ensure_directed_line_graph_compatibility(
graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float
graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7
) -> dgl.DGLGraph:
"""Ensure that line graph is compatible with graph.
Expand All @@ -222,9 +222,14 @@ def ensure_directed_line_graph_compatibility(
graph: atomistic graph
line_graph: line graph of atomistic graph
threebody_cutoff: cutoff for three-body interactions
tol: numerical tolerance for cutoff
"""
valid_edges = graph.edata["bond_dist"] <= threebody_cutoff
assert line_graph.number_of_nodes() <= sum(valid_edges), "line graph and graph are not compatible"

# this means there probably is a bond that is just at the cutoff
# this should only really occur when batching graphs
if line_graph.number_of_nodes() > sum(valid_edges):
valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol

edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()]
line_graph.ndata["edge_ids"] = edge_ids
Expand Down

0 comments on commit 4001318

Please sign in to comment.