From 4dc1f3e5f7c3b96c5c08528d27d2d2e8f90574f2 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 31 Aug 2023 20:50:00 -0700 Subject: [PATCH 01/15] remove numpy import --- matgl/layers/_atom_ref.py | 1 - 1 file changed, 1 deletion(-) diff --git a/matgl/layers/_atom_ref.py b/matgl/layers/_atom_ref.py index 81df068e..a16a1590 100644 --- a/matgl/layers/_atom_ref.py +++ b/matgl/layers/_atom_ref.py @@ -1,7 +1,6 @@ from __future__ import annotations import dgl -import numpy as np import torch from torch import nn From c8dc182ee0209ccbbad22706529720a5fa0a7464 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 1 Sep 2023 10:01:59 -0700 Subject: [PATCH 02/15] STY: ruff --- matgl/layers/_atom_ref.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matgl/layers/_atom_ref.py b/matgl/layers/_atom_ref.py index a16a1590..81df068e 100644 --- a/matgl/layers/_atom_ref.py +++ b/matgl/layers/_atom_ref.py @@ -1,6 +1,7 @@ from __future__ import annotations import dgl +import numpy as np import torch from torch import nn From fc01698ebef5f674cd517b5c3d3733846193769e Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 31 Aug 2023 20:50:00 -0700 Subject: [PATCH 03/15] remove numpy import --- matgl/layers/_atom_ref.py | 1 - 1 file changed, 1 deletion(-) diff --git a/matgl/layers/_atom_ref.py b/matgl/layers/_atom_ref.py index 81df068e..a16a1590 100644 --- a/matgl/layers/_atom_ref.py +++ b/matgl/layers/_atom_ref.py @@ -1,7 +1,6 @@ from __future__ import annotations import dgl -import numpy as np import torch from torch import nn From 2ae5d0729da1005715683955d4ba8aa1d79d2d37 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 1 Sep 2023 10:01:59 -0700 Subject: [PATCH 04/15] STY: ruff --- matgl/layers/_atom_ref.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matgl/layers/_atom_ref.py b/matgl/layers/_atom_ref.py index a16a1590..81df068e 100644 --- a/matgl/layers/_atom_ref.py +++ b/matgl/layers/_atom_ref.py @@ -1,6 +1,7 @@ from __future__ import annotations import dgl +import numpy as np import torch from torch import nn From 00503562016c7a675db0e97c23c090d9caca0d1c Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 29 Sep 2023 16:18:30 -0700 Subject: [PATCH 05/15] prune graph function --- matgl/graph/compute.py | 56 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 2f848639..fcf730d2 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -1,6 +1,9 @@ """Computing various graph based operations.""" + from __future__ import annotations +from typing import Callable + import dgl import numpy as np import torch @@ -131,12 +134,51 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float): Returns: l_g: DGL graph containing three body information from graph """ - valid_three_body = g.edata["bond_dist"] <= threebody_cutoff - src_id_with_three_body = g.edges()[0][valid_three_body] - dst_id_with_three_body = g.edges()[1][valid_three_body] - graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body)) - graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body] - graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body] - graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] + graph_with_three_body = prune_edges_by_features(g, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body) return l_g + + +def prune_edges_by_features( + graph: dgl.DGLGraph, + feat_name: str, + condition: Callable[[torch.Tensor], torch.Tensor], + keep_ndata: bool = False, + keep_edata: bool = True, + *args, + **kwargs, +) -> dgl.DGLGraph: + """Removes edges graph that do satisfy given condition based on a specified feature value. + + Returns a new graph with edges removed. + + Args: + graph: DGL graph + feat_name: edge field name + condition: condition function. Must be a function where the first is the value + of the edge field data and returns a Tensor of boolean values. + keep_ndata: whether to keep node features + keep_edata: whether to keep edge features + *args: additional arguments to pass to condition function + **kwargs: additional keyword arguments to pass to condition function + + Returns: dgl.Graph with removed edges. + """ + if feat_name not in graph.edata: + raise ValueError(f"Edge field {feat_name} not an edge feature in given graph.") + + valid_edges = torch.logical_not(condition(graph.edata[feat_name], *args, **kwargs)) + src, dst = graph.edges() + src, dst = src[valid_edges], dst[valid_edges] + e_ids = valid_edges.nonzero().squeeze() + new_g = dgl.graph((src, dst), device=graph.device) + new_g.edata["edge_ids"] = e_ids # keep track of original edge ids + + if keep_ndata: + for key, value in graph.ndata.items(): + new_g.ndata[key] = value + if keep_edata: + for key, value in graph.edata.items(): + new_g.edata[key] = value[valid_edges] + + return new_g From 50ea4dfabc3a1c562e2c3bebabad3bb3fd9721df Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 29 Sep 2023 16:19:12 -0700 Subject: [PATCH 06/15] directed line graph --- matgl/graph/compute.py | 66 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index fcf730d2..4398b863 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -139,6 +139,72 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float): return l_g +def create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph: + """Creates a line graph from a graph, considers periodic boundary conditions. + + Args: + graph: DGL graph representing atom graph + threebody_cutoff: cutoff for three-body interactions + + Returns: + line_graph: DGL graph line graph of pruned graph to three body cutoff + """ + with torch.no_grad(): + pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) + src_indices, dst_indices = pg.edges() + images = pg.edata["pbc_offset"] + all_indices = torch.arange(pg.number_of_nodes(), device=graph.device).unsqueeze(dim=0) + num_bonds_per_atom = torch.count_nonzero(src_indices.unsqueeze(dim=1) == all_indices, dim=0) + num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave(num_bonds_per_atom) + lg_src = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) + lg_dst = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) + + incoming_edges = src_indices.unsqueeze(1) == dst_indices + is_self_edge = src_indices == dst_indices + not_self_edge = ~is_self_edge + + n = 0 + # create line graph edges for bonds that are self edges in atom graph + if is_self_edge.any(): + edge_inds_s = is_self_edge.nonzero() + lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge] + 1) + lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze() + lg_src_s = lg_src_s[lg_src_s != lg_dst_s] + lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge]) + n = len(lg_dst_s) + lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s + + # create line graph edges for bonds that are not self edges in atom graph + shared_src = src_indices.unsqueeze(1) == src_indices + back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all(-images.unsqueeze(1) == images, axis=2) + incoming = incoming_edges & (shared_src | ~back_tracking) + + edge_inds_ns = not_self_edge.nonzero().squeeze() + lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze() + lg_dst_ns = edge_inds_ns.repeat_interleave(num_edges_per_bond[not_self_edge]) + lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns + lg = dgl.graph((lg_src, lg_dst)) + + for key in pg.edata: + lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()] + + # we need to store the sign of bond vector when a bond is a src node in the line + # graph in order to appropriately calculate angles when self edges are involved + lg.ndata["src_bond_sign"] = torch.ones( + (lg.number_of_nodes(), 1), dtype=lg.ndata["bond_vec"].dtype, device=lg.device + ) + # if we flip self edges then we need to correct computed angles by pi - angle + # lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s] + # find the intersection for the rare cases where not all edges end up as nodes in the line graph + all_ns, counts = torch.cat([torch.arange(lg.number_of_nodes(), device=graph.device), edge_inds_ns]).unique( + return_counts=True + ) + lg_inds_ns = all_ns[torch.where(counts > 1)] + lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][lg_inds_ns] + + return lg + + def prune_edges_by_features( graph: dgl.DGLGraph, feat_name: str, From 1ba58b62ee530882a32a1b83c01a05ba351c370f Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 29 Sep 2023 16:19:49 -0700 Subject: [PATCH 07/15] directed line graph compat --- matgl/graph/compute.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 4398b863..41225b26 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -205,6 +205,37 @@ def create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> return lg +def ensure_directed_line_graph_compatibility( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float +) -> dgl.DGLGraph: + """Ensure that line graph is compatible with graph. + + Sets edge data in line graph to be consistent with graph. The line graph is updated in place. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + """ + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + assert line_graph.number_of_nodes() <= sum(valid_edges), "line graph and graph are not compatible" + + edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] + line_graph.ndata["edge_ids"] = edge_ids + + for key in graph.edata: + line_graph.ndata[key] = graph.edata[key][edge_ids] + + src_indices, dst_indices = graph.edges() + ns_edge_ids = (src_indices[edge_ids] != dst_indices[edge_ids]).nonzero().squeeze() + line_graph.ndata["src_bond_sign"] = torch.ones( + (line_graph.number_of_nodes(), 1), dtype=graph.edata["bond_vec"].dtype, device=line_graph.device + ) + line_graph.ndata["src_bond_sign"][ns_edge_ids] = -line_graph.ndata["src_bond_sign"][ns_edge_ids] + + return line_graph + + def prune_edges_by_features( graph: dgl.DGLGraph, feat_name: str, From 4dc6038c1984b28879fdd0847a2052ef78d8ad56 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 29 Sep 2023 16:26:09 -0700 Subject: [PATCH 08/15] add directed option in compute_theta --- matgl/graph/compute.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 41225b26..ed736a65 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -1,5 +1,4 @@ """Computing various graph based operations.""" - from __future__ import annotations from typing import Callable @@ -99,25 +98,32 @@ def compute_theta_and_phi(edges: dgl.udf.EdgeBatch): phi: torch.Tensor triple_bond_lengths (torch.tensor): """ - angles = compute_theta(edges, cosine=True) + angles = compute_theta(edges, cosine=True, directed=False) angles["phi"] = torch.zeros_like(angles["cos_theta"]) return angles -def compute_theta(edges: dgl.udf.EdgeBatch, cosine: bool = False) -> dict[str, torch.Tensor]: +def compute_theta( + edges: dgl.udf.EdgeBatch, cosine: bool = False, directed: bool = True, eps=1e-7 +) -> dict[str, torch.Tensor]: """User defined dgl function to calculate bond angles from edges in a graph. Args: edges: DGL graph edges cosine: Whether to return the cosine of the angle or the angle itself + directed: Whether to the line graph was created with create directed line graph. + In which case bonds (only those that are not self bonds) need to + have their bond vectors flipped. + eps: eps value used to clamp cosine values to avoid acos of values > 1.0 Returns: dict[str, torch.Tensor]: Dictionary containing bond angles and distances """ - vec1 = edges.src["bond_vec"] + vec1 = edges.src["bond_vec"] * edges.src["src_bond_sign"] if directed else edges.src["bond_vec"] vec2 = edges.dst["bond_vec"] key = "cos_theta" if cosine else "theta" val = torch.sum(vec1 * vec2, dim=1) / (torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1)) + val = val.clamp_(min=-1 + eps, max=1 - eps) # stability for floating point numbers > 1.0 if not cosine: val = torch.acos(val) return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]} From 0f97de37fb9ec6e17ade9dd90215c4d2cda0526d Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 29 Sep 2023 16:26:18 -0700 Subject: [PATCH 09/15] tests --- tests/graph/test_compute.py | 94 +++++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/tests/graph/test_compute.py b/tests/graph/test_compute.py index be0d118a..d175a868 100644 --- a/tests/graph/test_compute.py +++ b/tests/graph/test_compute.py @@ -3,7 +3,9 @@ from functools import partial import numpy as np +import pytest import torch +import torch.testing as tt from pymatgen.core import Lattice, Structure from matgl.ext.pymatgen import Structure2Graph, get_element_list @@ -11,7 +13,10 @@ compute_pair_vector_and_distance, compute_theta, compute_theta_and_phi, + create_directed_line_graph, create_line_graph, + ensure_directed_line_graph_compatibility, + prune_edges_by_features, ) @@ -46,8 +51,8 @@ def _calculate_cos_loop(graph, threebody_cutoff=4.0): for j in range(n_site): if i == j: continue - vi = graph.edata["bond_vec"][i + start_index].numpy() - vj = graph.edata["bond_vec"][j + start_index].numpy() + vi = graph.edata["bond_vec"][i + start_index].detach().numpy() + vj = graph.edata["bond_vec"][j + start_index].detach().numpy() di = np.linalg.norm(vi) dj = np.linalg.norm(vj) if (di <= threebody_cutoff) and (dj <= threebody_cutoff): @@ -114,14 +119,13 @@ def test_compute_angle(self, graph_Mo, graph_CH4): ) # test only compute theta - line_graph.apply_edges(compute_theta) - np.testing.assert_array_almost_equal( - np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata["theta"])), decimal=4 - ) + line_graph.apply_edges(partial(compute_theta, directed=False)) + theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7)) + np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata["theta"])), decimal=4) # test only compute theta with cosine _ = line_graph.edata.pop("cos_theta") - line_graph.apply_edges(partial(compute_theta, cosine=True)) + line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False)) np.testing.assert_array_almost_equal( np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata["cos_theta"])) ) @@ -140,14 +144,14 @@ def test_compute_angle(self, graph_Mo, graph_CH4): ) # test only compute theta - line_graph.apply_edges(compute_theta) + line_graph.apply_edges(partial(compute_theta, directed=False)) np.testing.assert_array_almost_equal( np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata["theta"])) ) # test only compute theta with cosine _ = line_graph.edata.pop("cos_theta") - line_graph.apply_edges(partial(compute_theta, cosine=True)) + line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False)) np.testing.assert_array_almost_equal( np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata["cos_theta"])) ) @@ -186,3 +190,75 @@ def test_line_graph_extensive(): assert 2 * g1.number_of_edges() == g2.number_of_edges() assert 2 * lg1.number_of_nodes() == lg2.number_of_nodes() assert 2 * lg1.number_of_edges() == lg2.number_of_edges() + + +@pytest.mark.parametrize("keep_ndata", [True, False]) +@pytest.mark.parametrize("keep_edata", [True, False]) +def test_remove_edges_by_features(graph_Mo, keep_ndata, keep_edata): + s1, g1, state1 = graph_Mo + bv, bd = compute_pair_vector_and_distance(g1) + g1.edata["bond_vec"] = bv + g1.edata["bond_dist"] = bd + + new_cutoff = 3.0 + converter = Structure2Graph(element_types=get_element_list([s1]), cutoff=new_cutoff) + g2, state2 = converter.get_graph(s1) + + # remove edges by features + new_g = prune_edges_by_features( + g1, "bond_dist", condition=lambda x: x > new_cutoff, keep_ndata=keep_ndata, keep_edata=keep_edata + ) + valid_edges = g1.edata["bond_dist"] <= new_cutoff + + assert new_g.num_edges() == g2.num_edges() + assert new_g.num_nodes() == g2.num_nodes() + assert torch.allclose(new_g.edata["edge_ids"], valid_edges.nonzero().squeeze()) + + if keep_ndata: + assert new_g.ndata.keys() == g1.ndata.keys() + + if keep_edata: + for key in g1.edata: + if key != "edge_ids": + assert torch.allclose(new_g.edata[key], g1.edata[key][valid_edges]) + + +@pytest.mark.parametrize("cutoff", [2.0, 3.0, 4.0]) +@pytest.mark.parametrize("graph_data", ["graph_Mo", "graph_CH4", "graph_MoS", "graph_LiFePO4", "graph_MoSH"]) +def test_directed_line_graph(graph_data, cutoff, request): + s1, g1, state1 = request.getfixturevalue(graph_data) + bv, bd = compute_pair_vector_and_distance(g1) + g1.edata["bond_vec"] = bv + g1.edata["bond_dist"] = bd + cos_loop = _calculate_cos_loop(g1, cutoff) + theta_loop = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7)) + + line_graph = create_directed_line_graph(g1, cutoff) + line_graph.apply_edges(compute_theta) + + # this test might be lax with just 4 decimal places + np.testing.assert_array_almost_equal(np.sort(theta_loop), np.sort(np.array(line_graph.edata["theta"])), decimal=4) + + +@pytest.mark.parametrize("graph_data", ["graph_Mo", "graph_CH4", "graph_LiFePO4", "graph_MoSH"]) +def test_ensure_directed_line_graph_compat(graph_data, request): + s, g, state = request.getfixturevalue(graph_data) + bv, bd = compute_pair_vector_and_distance(g) + g.edata["bond_vec"] = bv + g.edata["bond_dist"] = bd + line_graph = create_directed_line_graph(g, 3.0) + edge_ids = line_graph.ndata["edge_ids"].clone() + src_bond_sign = line_graph.ndata["src_bond_sign"].clone() + line_graph.ndata["edge_ids"] = torch.zeros(line_graph.num_nodes(), dtype=torch.long) + line_graph.ndata["src_bond_sign"] = torch.zeros(line_graph.num_nodes()) + + assert not torch.allclose(line_graph.ndata["edge_ids"], edge_ids) + assert not torch.allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) + + # test that the line graph is not compatible + line_graph = ensure_directed_line_graph_compatibility(g, line_graph, 3.0) + tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids) + tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) + + with pytest.raises(AssertionError): + ensure_directed_line_graph_compatibility(g, line_graph, 1.0) From 4001318b52fd23177c70071f46c1242e37451d7d Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 11:24:18 -0700 Subject: [PATCH 10/15] add tolerance in lg compatibility --- matgl/graph/compute.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index ed736a65..c1ea20a3 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -212,7 +212,7 @@ def create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> def ensure_directed_line_graph_compatibility( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7 ) -> dgl.DGLGraph: """Ensure that line graph is compatible with graph. @@ -222,9 +222,14 @@ def ensure_directed_line_graph_compatibility( graph: atomistic graph line_graph: line graph of atomistic graph threebody_cutoff: cutoff for three-body interactions + tol: numerical tolerance for cutoff """ valid_edges = graph.edata["bond_dist"] <= threebody_cutoff - assert line_graph.number_of_nodes() <= sum(valid_edges), "line graph and graph are not compatible" + + # this means there probably is a bond that is just at the cutoff + # this should only really occur when batching graphs + if line_graph.number_of_nodes() > sum(valid_edges): + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] line_graph.ndata["edge_ids"] = edge_ids From e32faa75139868e505985130f81ad4c7fddc204a Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 12:14:14 -0700 Subject: [PATCH 11/15] raise runtime error for incompatible graph --- matgl/graph/compute.py | 4 ++++ tests/graph/test_compute.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index c1ea20a3..43c34eb2 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -231,6 +231,10 @@ def ensure_directed_line_graph_compatibility( if line_graph.number_of_nodes() > sum(valid_edges): valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol + # check again and raise if invalid + if line_graph.number_of_nodes() > sum(valid_edges): + raise RuntimeError("Line graph is not compatible with graph.") + edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] line_graph.ndata["edge_ids"] = edge_ids diff --git a/tests/graph/test_compute.py b/tests/graph/test_compute.py index d175a868..71cb3fae 100644 --- a/tests/graph/test_compute.py +++ b/tests/graph/test_compute.py @@ -260,5 +260,5 @@ def test_ensure_directed_line_graph_compat(graph_data, request): tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids) tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) - with pytest.raises(AssertionError): + with pytest.raises(RuntimeError): ensure_directed_line_graph_compatibility(g, line_graph, 1.0) From e4753295bced481d799912c933115d92dd10c422 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 14:33:34 -0700 Subject: [PATCH 12/15] single create_line_graph function --- matgl/graph/compute.py | 257 ++++++++++++++++++------------------ tests/graph/test_compute.py | 5 +- 2 files changed, 133 insertions(+), 129 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 43c34eb2..ab20f909 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -10,65 +10,6 @@ import matgl -def compute_3body(g: dgl.DGLGraph): - """Calculate the three body indices from pair atom indices. - - Args: - g: DGL graph - - Returns: - l_g: DGL graph containing three body information from graph - triple_bond_indices (np.ndarray): bond indices that form three-body - n_triple_ij (np.ndarray): number of three-body angles for each bond - n_triple_i (np.ndarray): number of three-body angles each atom - n_triple_s (np.ndarray): number of three-body angles for each structure - """ - n_atoms = g.num_nodes() - first_col = g.edges()[0].cpu().numpy().reshape(-1, 1) - all_indices = np.arange(n_atoms).reshape(1, -1) - n_bond_per_atom = np.count_nonzero(first_col == all_indices, axis=0) - n_triple_i = n_bond_per_atom * (n_bond_per_atom - 1) - n_triple = np.sum(n_triple_i) - n_triple_ij = np.repeat(n_bond_per_atom - 1, n_bond_per_atom) - triple_bond_indices = np.empty((n_triple, 2), dtype=matgl.int_np) # type: ignore - - start = 0 - cs = 0 - for n in n_bond_per_atom: - if n > 0: - """ - triple_bond_indices is generated from all pair permutations of atom indices. The - numpy version below does this with much greater efficiency. The equivalent slow - code is: - - ``` - for j, k in itertools.permutations(range(n), 2): - triple_bond_indices[index] = [start + j, start + k] - ``` - """ - r = np.arange(n) - x, y = np.meshgrid(r, r, indexing="xy") - c = np.stack([y.ravel(), x.ravel()], axis=1) - final = c[c[:, 0] != c[:, 1]] - triple_bond_indices[start : start + (n * (n - 1)), :] = final + cs - start += n * (n - 1) - cs += n - - n_triple_s = [np.sum(n_triple_i[0:n_atoms])] - src_id = torch.tensor(triple_bond_indices[:, 0], dtype=matgl.int_th) - dst_id = torch.tensor(triple_bond_indices[:, 1], dtype=matgl.int_th) - l_g = dgl.graph((src_id, dst_id)) - three_body_id = torch.concatenate(l_g.edges()) - n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th) - max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 - l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id] - l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id] - l_g.ndata["n_triple_ij"] = n_triple_ij[:max_three_body_id] - n_triple_s = torch.tensor(n_triple_s, dtype=matgl.int_th) # type: ignore - return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s - - def compute_pair_vector_and_distance(g: dgl.DGLGraph): """Calculate bond vectors and distances using dgl graphs. @@ -129,84 +70,23 @@ def compute_theta( return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]} -def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float): +def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False) -> dgl.DGLGraph: """ Calculate the three body indices from pair atom indices. Args: g: DGL graph threebody_cutoff (float): cutoff for three-body interactions + directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet) Returns: l_g: DGL graph containing three body information from graph """ graph_with_three_body = prune_edges_by_features(g, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) - l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body) - return l_g - - -def create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph: - """Creates a line graph from a graph, considers periodic boundary conditions. - - Args: - graph: DGL graph representing atom graph - threebody_cutoff: cutoff for three-body interactions - - Returns: - line_graph: DGL graph line graph of pruned graph to three body cutoff - """ - with torch.no_grad(): - pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) - src_indices, dst_indices = pg.edges() - images = pg.edata["pbc_offset"] - all_indices = torch.arange(pg.number_of_nodes(), device=graph.device).unsqueeze(dim=0) - num_bonds_per_atom = torch.count_nonzero(src_indices.unsqueeze(dim=1) == all_indices, dim=0) - num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave(num_bonds_per_atom) - lg_src = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) - lg_dst = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) - - incoming_edges = src_indices.unsqueeze(1) == dst_indices - is_self_edge = src_indices == dst_indices - not_self_edge = ~is_self_edge - - n = 0 - # create line graph edges for bonds that are self edges in atom graph - if is_self_edge.any(): - edge_inds_s = is_self_edge.nonzero() - lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge] + 1) - lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze() - lg_src_s = lg_src_s[lg_src_s != lg_dst_s] - lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge]) - n = len(lg_dst_s) - lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s - - # create line graph edges for bonds that are not self edges in atom graph - shared_src = src_indices.unsqueeze(1) == src_indices - back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all(-images.unsqueeze(1) == images, axis=2) - incoming = incoming_edges & (shared_src | ~back_tracking) - - edge_inds_ns = not_self_edge.nonzero().squeeze() - lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze() - lg_dst_ns = edge_inds_ns.repeat_interleave(num_edges_per_bond[not_self_edge]) - lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns - lg = dgl.graph((lg_src, lg_dst)) - - for key in pg.edata: - lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()] - - # we need to store the sign of bond vector when a bond is a src node in the line - # graph in order to appropriately calculate angles when self edges are involved - lg.ndata["src_bond_sign"] = torch.ones( - (lg.number_of_nodes(), 1), dtype=lg.ndata["bond_vec"].dtype, device=lg.device - ) - # if we flip self edges then we need to correct computed angles by pi - angle - # lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s] - # find the intersection for the rare cases where not all edges end up as nodes in the line graph - all_ns, counts = torch.cat([torch.arange(lg.number_of_nodes(), device=graph.device), edge_inds_ns]).unique( - return_counts=True - ) - lg_inds_ns = all_ns[torch.where(counts > 1)] - lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][lg_inds_ns] + if directed: + lg = _create_directed_line_graph(graph_with_three_body, threebody_cutoff) + else: + lg, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = _compute_3body(graph_with_three_body) return lg @@ -294,3 +174,128 @@ def prune_edges_by_features( new_g.edata[key] = value[valid_edges] return new_g + + +def _compute_3body(g: dgl.DGLGraph): + """Calculate the three body indices from pair atom indices. + + Args: + g: DGL graph + + Returns: + l_g: DGL graph containing three body information from graph + triple_bond_indices (np.ndarray): bond indices that form three-body + n_triple_ij (np.ndarray): number of three-body angles for each bond + n_triple_i (np.ndarray): number of three-body angles each atom + n_triple_s (np.ndarray): number of three-body angles for each structure + """ + n_atoms = g.num_nodes() + first_col = g.edges()[0].cpu().numpy().reshape(-1, 1) + all_indices = np.arange(n_atoms).reshape(1, -1) + n_bond_per_atom = np.count_nonzero(first_col == all_indices, axis=0) + n_triple_i = n_bond_per_atom * (n_bond_per_atom - 1) + n_triple = np.sum(n_triple_i) + n_triple_ij = np.repeat(n_bond_per_atom - 1, n_bond_per_atom) + triple_bond_indices = np.empty((n_triple, 2), dtype=matgl.int_np) # type: ignore + + start = 0 + cs = 0 + for n in n_bond_per_atom: + if n > 0: + """ + triple_bond_indices is generated from all pair permutations of atom indices. The + numpy version below does this with much greater efficiency. The equivalent slow + code is: + + ``` + for j, k in itertools.permutations(range(n), 2): + triple_bond_indices[index] = [start + j, start + k] + ``` + """ + r = np.arange(n) + x, y = np.meshgrid(r, r, indexing="xy") + c = np.stack([y.ravel(), x.ravel()], axis=1) + final = c[c[:, 0] != c[:, 1]] + triple_bond_indices[start : start + (n * (n - 1)), :] = final + cs + start += n * (n - 1) + cs += n + + n_triple_s = [np.sum(n_triple_i[0:n_atoms])] + src_id = torch.tensor(triple_bond_indices[:, 0], dtype=matgl.int_th) + dst_id = torch.tensor(triple_bond_indices[:, 1], dtype=matgl.int_th) + l_g = dgl.graph((src_id, dst_id)) + three_body_id = torch.concatenate(l_g.edges()) + n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th) + max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 + l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id] + l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id] + l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id] + l_g.ndata["n_triple_ij"] = n_triple_ij[:max_three_body_id] + n_triple_s = torch.tensor(n_triple_s, dtype=matgl.int_th) # type: ignore + return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s + + +def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph: + """Creates a line graph from a graph, considers periodic boundary conditions. + + Args: + graph: DGL graph representing atom graph + threebody_cutoff: cutoff for three-body interactions + + Returns: + line_graph: DGL graph line graph of pruned graph to three body cutoff + """ + with torch.no_grad(): + pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff) + src_indices, dst_indices = pg.edges() + images = pg.edata["pbc_offset"] + all_indices = torch.arange(pg.number_of_nodes(), device=graph.device).unsqueeze(dim=0) + num_bonds_per_atom = torch.count_nonzero(src_indices.unsqueeze(dim=1) == all_indices, dim=0) + num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave(num_bonds_per_atom) + lg_src = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) + lg_dst = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device) + + incoming_edges = src_indices.unsqueeze(1) == dst_indices + is_self_edge = src_indices == dst_indices + not_self_edge = ~is_self_edge + + n = 0 + # create line graph edges for bonds that are self edges in atom graph + if is_self_edge.any(): + edge_inds_s = is_self_edge.nonzero() + lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge] + 1) + lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze() + lg_src_s = lg_src_s[lg_src_s != lg_dst_s] + lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge]) + n = len(lg_dst_s) + lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s + + # create line graph edges for bonds that are not self edges in atom graph + shared_src = src_indices.unsqueeze(1) == src_indices + back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all(-images.unsqueeze(1) == images, axis=2) + incoming = incoming_edges & (shared_src | ~back_tracking) + + edge_inds_ns = not_self_edge.nonzero().squeeze() + lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze() + lg_dst_ns = edge_inds_ns.repeat_interleave(num_edges_per_bond[not_self_edge]) + lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns + lg = dgl.graph((lg_src, lg_dst)) + + for key in pg.edata: + lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()] + + # we need to store the sign of bond vector when a bond is a src node in the line + # graph in order to appropriately calculate angles when self edges are involved + lg.ndata["src_bond_sign"] = torch.ones( + (lg.number_of_nodes(), 1), dtype=lg.ndata["bond_vec"].dtype, device=lg.device + ) + # if we flip self edges then we need to correct computed angles by pi - angle + # lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s] + # find the intersection for the rare cases where not all edges end up as nodes in the line graph + all_ns, counts = torch.cat([torch.arange(lg.number_of_nodes(), device=graph.device), edge_inds_ns]).unique( + return_counts=True + ) + lg_inds_ns = all_ns[torch.where(counts > 1)] + lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][lg_inds_ns] + + return lg diff --git a/tests/graph/test_compute.py b/tests/graph/test_compute.py index 71cb3fae..3d4528f3 100644 --- a/tests/graph/test_compute.py +++ b/tests/graph/test_compute.py @@ -13,7 +13,6 @@ compute_pair_vector_and_distance, compute_theta, compute_theta_and_phi, - create_directed_line_graph, create_line_graph, ensure_directed_line_graph_compatibility, prune_edges_by_features, @@ -233,7 +232,7 @@ def test_directed_line_graph(graph_data, cutoff, request): cos_loop = _calculate_cos_loop(g1, cutoff) theta_loop = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7)) - line_graph = create_directed_line_graph(g1, cutoff) + line_graph = create_line_graph(g1, cutoff, directed=True) line_graph.apply_edges(compute_theta) # this test might be lax with just 4 decimal places @@ -246,7 +245,7 @@ def test_ensure_directed_line_graph_compat(graph_data, request): bv, bd = compute_pair_vector_and_distance(g) g.edata["bond_vec"] = bv g.edata["bond_dist"] = bd - line_graph = create_directed_line_graph(g, 3.0) + line_graph = create_line_graph(g, 3.0, directed=True) edge_ids = line_graph.ndata["edge_ids"].clone() src_bond_sign = line_graph.ndata["src_bond_sign"].clone() line_graph.ndata["edge_ids"] = torch.zeros(line_graph.num_nodes(), dtype=torch.long) From 5bf1d9cadc502f496d36d1fa4e38505a170746a0 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 14:41:38 -0700 Subject: [PATCH 13/15] single ensure lg compatibility function --- matgl/graph/compute.py | 97 +++++++++++++++++++++++++++---------- tests/graph/test_compute.py | 6 +-- 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index ab20f909..220f8857 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -91,8 +91,8 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = return lg -def ensure_directed_line_graph_compatibility( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7 +def ensure_line_graph_compatibility( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False, tol: float = 5e-7 ) -> dgl.DGLGraph: """Ensure that line graph is compatible with graph. @@ -102,31 +102,13 @@ def ensure_directed_line_graph_compatibility( graph: atomistic graph line_graph: line graph of atomistic graph threebody_cutoff: cutoff for three-body interactions + directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet) tol: numerical tolerance for cutoff """ - valid_edges = graph.edata["bond_dist"] <= threebody_cutoff - - # this means there probably is a bond that is just at the cutoff - # this should only really occur when batching graphs - if line_graph.number_of_nodes() > sum(valid_edges): - valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol - - # check again and raise if invalid - if line_graph.number_of_nodes() > sum(valid_edges): - raise RuntimeError("Line graph is not compatible with graph.") - - edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] - line_graph.ndata["edge_ids"] = edge_ids - - for key in graph.edata: - line_graph.ndata[key] = graph.edata[key][edge_ids] - - src_indices, dst_indices = graph.edges() - ns_edge_ids = (src_indices[edge_ids] != dst_indices[edge_ids]).nonzero().squeeze() - line_graph.ndata["src_bond_sign"] = torch.ones( - (line_graph.number_of_nodes(), 1), dtype=graph.edata["bond_vec"].dtype, device=line_graph.device - ) - line_graph.ndata["src_bond_sign"][ns_edge_ids] = -line_graph.ndata["src_bond_sign"][ns_edge_ids] + if directed: + line_graph = _ensure_directed_line_graph_compatibility(graph, line_graph, threebody_cutoff, tol) + else: + line_graph = _ensure_3body_line_graph_compatibility(graph, line_graph, threebody_cutoff) return line_graph @@ -299,3 +281,68 @@ def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][lg_inds_ns] return lg + + +def _ensure_3body_line_graph_compatibility(graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float): + """Ensure that 3body line graph is compatible with a given graph. + + Sets edge data in line graph to be consistent with graph. The line graph is updated in place. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + """ + valid_three_body = graph.edata["bond_dist"] <= threebody_cutoff + if line_graph.num_nodes() == graph.edata["bond_vec"][valid_three_body].shape[0]: + line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][valid_three_body] + line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][valid_three_body] + line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][valid_three_body] + else: + three_body_id = torch.concatenate(line_graph.edges()) + max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 + line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][:max_three_body_id] + line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][:max_three_body_id] + line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][:max_three_body_id] + + return line_graph + + +def _ensure_directed_line_graph_compatibility( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7 +) -> dgl.DGLGraph: + """Ensure that line graph is compatible with graph. + + Sets edge data in line graph to be consistent with graph. The line graph is updated in place. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + tol: numerical tolerance for cutoff + """ + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + + # this means there probably is a bond that is just at the cutoff + # this should only really occur when batching graphs + if line_graph.number_of_nodes() > sum(valid_edges): + valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol + + # check again and raise if invalid + if line_graph.number_of_nodes() > sum(valid_edges): + raise RuntimeError("Line graph is not compatible with graph.") + + edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()] + line_graph.ndata["edge_ids"] = edge_ids + + for key in graph.edata: + line_graph.ndata[key] = graph.edata[key][edge_ids] + + src_indices, dst_indices = graph.edges() + ns_edge_ids = (src_indices[edge_ids] != dst_indices[edge_ids]).nonzero().squeeze() + line_graph.ndata["src_bond_sign"] = torch.ones( + (line_graph.number_of_nodes(), 1), dtype=graph.edata["bond_vec"].dtype, device=line_graph.device + ) + line_graph.ndata["src_bond_sign"][ns_edge_ids] = -line_graph.ndata["src_bond_sign"][ns_edge_ids] + + return line_graph \ No newline at end of file diff --git a/tests/graph/test_compute.py b/tests/graph/test_compute.py index 3d4528f3..70344d22 100644 --- a/tests/graph/test_compute.py +++ b/tests/graph/test_compute.py @@ -14,7 +14,7 @@ compute_theta, compute_theta_and_phi, create_line_graph, - ensure_directed_line_graph_compatibility, + ensure_line_graph_compatibility, prune_edges_by_features, ) @@ -255,9 +255,9 @@ def test_ensure_directed_line_graph_compat(graph_data, request): assert not torch.allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) # test that the line graph is not compatible - line_graph = ensure_directed_line_graph_compatibility(g, line_graph, 3.0) + line_graph = ensure_line_graph_compatibility(g, line_graph, 3.0, directed=True) tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids) tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign) with pytest.raises(RuntimeError): - ensure_directed_line_graph_compatibility(g, line_graph, 1.0) + ensure_line_graph_compatibility(g, line_graph, 1.0, directed=True) From 1dfcaf07c1e8c3b8d94e1601295ff1c5ac61001f Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 15:06:05 -0700 Subject: [PATCH 14/15] use ensure lg compatibility function in m3gnet --- matgl/models/_m3gnet.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/matgl/models/_m3gnet.py b/matgl/models/_m3gnet.py index 65e392a5..cacbacb9 100644 --- a/matgl/models/_m3gnet.py +++ b/matgl/models/_m3gnet.py @@ -22,6 +22,7 @@ compute_pair_vector_and_distance, compute_theta_and_phi, create_line_graph, + ensure_line_graph_compatibility ) from matgl.layers import ( MLP, @@ -232,17 +233,7 @@ def forward( if l_g is None: l_g = create_line_graph(g, self.threebody_cutoff) else: - valid_three_body = g.edata["bond_dist"] <= self.threebody_cutoff - if l_g.num_nodes() == g.edata["bond_vec"][valid_three_body].shape[0]: - l_g.ndata["bond_vec"] = g.edata["bond_vec"][valid_three_body] - l_g.ndata["bond_dist"] = g.edata["bond_dist"][valid_three_body] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] - else: - three_body_id = torch.concatenate(l_g.edges()) - max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 - l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id] - l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id] + l_g = ensure_line_graph_compatibility(g, l_g, self.threebody_cutoff) l_g.apply_edges(compute_theta_and_phi) g.edata["rbf"] = expanded_dists three_body_basis = self.basis_expansion(l_g) From 52d1e17590045ff8d1d4f5c955f7d3bd1e2f2854 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 2 Oct 2023 15:07:54 -0700 Subject: [PATCH 15/15] ruff --- matgl/graph/compute.py | 2 +- matgl/models/_m3gnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index 220f8857..1d724cfb 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -345,4 +345,4 @@ def _ensure_directed_line_graph_compatibility( ) line_graph.ndata["src_bond_sign"][ns_edge_ids] = -line_graph.ndata["src_bond_sign"][ns_edge_ids] - return line_graph \ No newline at end of file + return line_graph diff --git a/matgl/models/_m3gnet.py b/matgl/models/_m3gnet.py index cacbacb9..0ea981e0 100644 --- a/matgl/models/_m3gnet.py +++ b/matgl/models/_m3gnet.py @@ -22,7 +22,7 @@ compute_pair_vector_and_distance, compute_theta_and_phi, create_line_graph, - ensure_line_graph_compatibility + ensure_line_graph_compatibility, ) from matgl.layers import ( MLP,