From 4001318b52fd23177c70071f46c1242e37451d7d Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 11:24:18 -0700 Subject: [PATCH] add tolerance in lg compatibility --- matgl/graph/compute.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index ed736a65..c1ea20a3 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -212,7 +212,7 @@ def create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> def ensure_directed_line_graph_compatibility( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7 ) -> dgl.DGLGraph: """Ensure that line graph is compatible with graph. @@ -222,9 +222,14 @@ def ensure_directed_line_graph_compatibility( graph: atomistic graph line_graph: line graph of atomistic graph threebody_cutoff: cutoff for three-body interactions + tol: numerical tolerance for cutoff """ valid_edges = graph.edata["bond_dist"] <= threebody_cutoff - assert line_graph.number_of_nodes() <= sum(valid_edges), "line graph and graph are not compatible" + + # this means there probably is a bond that is just at the cutoff + # this should only really occur when batching graphs + if line_graph.number_of_nodes() > sum(valid_edges): + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] line_graph.ndata["edge_ids"] = edge_ids