Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Directed line graph #178

Merged
merged 15 commits into from
Oct 3, 2023
314 changes: 260 additions & 54 deletions matgl/graph/compute.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,164 @@
"""Computing various graph based operations."""
from __future__ import annotations

from typing import Callable

import dgl
import numpy as np
import torch

import matgl


def compute_3body(g: dgl.DGLGraph):
def compute_pair_vector_and_distance(g: dgl.DGLGraph):
"""Calculate bond vectors and distances using dgl graphs.

Args:
g: DGL graph

Returns:
bond_vec (torch.tensor): bond distance between two atoms
bond_dist (torch.tensor): vector from src node to dst node
"""
dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["pbc_offshift"]
src_pos = g.ndata["pos"][g.edges()[0]]
bond_vec = (dst_pos - src_pos).float()
bond_dist = torch.norm(bond_vec, dim=1)

return bond_vec, bond_dist


def compute_theta_and_phi(edges: dgl.udf.EdgeBatch):
"""Calculate bond angle Theta and Phi using dgl graphs.

Args:
edges: DGL graph edges

Returns:
cos_theta: torch.Tensor
phi: torch.Tensor
triple_bond_lengths (torch.tensor):
"""
angles = compute_theta(edges, cosine=True, directed=False)
angles["phi"] = torch.zeros_like(angles["cos_theta"])
return angles


def compute_theta(
edges: dgl.udf.EdgeBatch, cosine: bool = False, directed: bool = True, eps=1e-7
) -> dict[str, torch.Tensor]:
"""User defined dgl function to calculate bond angles from edges in a graph.

Args:
edges: DGL graph edges
cosine: Whether to return the cosine of the angle or the angle itself
directed: Whether to the line graph was created with create directed line graph.
In which case bonds (only those that are not self bonds) need to
have their bond vectors flipped.
eps: eps value used to clamp cosine values to avoid acos of values > 1.0

Returns:
dict[str, torch.Tensor]: Dictionary containing bond angles and distances
"""
vec1 = edges.src["bond_vec"] * edges.src["src_bond_sign"] if directed else edges.src["bond_vec"]
vec2 = edges.dst["bond_vec"]
key = "cos_theta" if cosine else "theta"
val = torch.sum(vec1 * vec2, dim=1) / (torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1))
val = val.clamp_(min=-1 + eps, max=1 - eps) # stability for floating point numbers > 1.0
if not cosine:
val = torch.acos(val)
return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]}


def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False) -> dgl.DGLGraph:
"""
Calculate the three body indices from pair atom indices.

Args:
g: DGL graph
threebody_cutoff (float): cutoff for three-body interactions
directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet)

Returns:
l_g: DGL graph containing three body information from graph
"""
graph_with_three_body = prune_edges_by_features(g, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff)
if directed:
lg = _create_directed_line_graph(graph_with_three_body, threebody_cutoff)
else:
lg, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = _compute_3body(graph_with_three_body)

return lg


def ensure_line_graph_compatibility(
graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False, tol: float = 5e-7
) -> dgl.DGLGraph:
"""Ensure that line graph is compatible with graph.

Sets edge data in line graph to be consistent with graph. The line graph is updated in place.

Args:
graph: atomistic graph
line_graph: line graph of atomistic graph
threebody_cutoff: cutoff for three-body interactions
directed (bool): Whether to create a directed line graph, or an m3gnet 3body line graph (default: False, m3gnet)
tol: numerical tolerance for cutoff
"""
if directed:
line_graph = _ensure_directed_line_graph_compatibility(graph, line_graph, threebody_cutoff, tol)
else:
line_graph = _ensure_3body_line_graph_compatibility(graph, line_graph, threebody_cutoff)

return line_graph


def prune_edges_by_features(
graph: dgl.DGLGraph,
feat_name: str,
condition: Callable[[torch.Tensor], torch.Tensor],
keep_ndata: bool = False,
keep_edata: bool = True,
*args,
**kwargs,
) -> dgl.DGLGraph:
"""Removes edges graph that do satisfy given condition based on a specified feature value.

Returns a new graph with edges removed.

Args:
graph: DGL graph
feat_name: edge field name
condition: condition function. Must be a function where the first is the value
of the edge field data and returns a Tensor of boolean values.
keep_ndata: whether to keep node features
keep_edata: whether to keep edge features
*args: additional arguments to pass to condition function
**kwargs: additional keyword arguments to pass to condition function

Returns: dgl.Graph with removed edges.
"""
if feat_name not in graph.edata:
raise ValueError(f"Edge field {feat_name} not an edge feature in given graph.")

valid_edges = torch.logical_not(condition(graph.edata[feat_name], *args, **kwargs))
src, dst = graph.edges()
src, dst = src[valid_edges], dst[valid_edges]
e_ids = valid_edges.nonzero().squeeze()
new_g = dgl.graph((src, dst), device=graph.device)
new_g.edata["edge_ids"] = e_ids # keep track of original edge ids

if keep_ndata:
for key, value in graph.ndata.items():
new_g.ndata[key] = value
if keep_edata:
for key, value in graph.edata.items():
new_g.edata[key] = value[valid_edges]

return new_g


def _compute_3body(g: dgl.DGLGraph):
"""Calculate the three body indices from pair atom indices.

Args:
Expand Down Expand Up @@ -67,76 +217,132 @@ def compute_3body(g: dgl.DGLGraph):
return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s


def compute_pair_vector_and_distance(g: dgl.DGLGraph):
"""Calculate bond vectors and distances using dgl graphs.
def _create_directed_line_graph(graph: dgl.DGLGraph, threebody_cutoff: float) -> dgl.DGLGraph:
"""Creates a line graph from a graph, considers periodic boundary conditions.

Args:
g: DGL graph
graph: DGL graph representing atom graph
threebody_cutoff: cutoff for three-body interactions

Returns:
bond_vec (torch.tensor): bond distance between two atoms
bond_dist (torch.tensor): vector from src node to dst node
line_graph: DGL graph line graph of pruned graph to three body cutoff
"""
dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["pbc_offshift"]
src_pos = g.ndata["pos"][g.edges()[0]]
bond_vec = (dst_pos - src_pos).float()
bond_dist = torch.norm(bond_vec, dim=1)
with torch.no_grad():
pg = prune_edges_by_features(graph, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff)
src_indices, dst_indices = pg.edges()
images = pg.edata["pbc_offset"]
all_indices = torch.arange(pg.number_of_nodes(), device=graph.device).unsqueeze(dim=0)
num_bonds_per_atom = torch.count_nonzero(src_indices.unsqueeze(dim=1) == all_indices, dim=0)
num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave(num_bonds_per_atom)
lg_src = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device)
lg_dst = torch.empty(num_edges_per_bond.sum(), dtype=matgl.int_th, device=graph.device)

return bond_vec, bond_dist
incoming_edges = src_indices.unsqueeze(1) == dst_indices
is_self_edge = src_indices == dst_indices
not_self_edge = ~is_self_edge

n = 0
# create line graph edges for bonds that are self edges in atom graph
if is_self_edge.any():
edge_inds_s = is_self_edge.nonzero()
lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge] + 1)
lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze()
lg_src_s = lg_src_s[lg_src_s != lg_dst_s]
lg_dst_s = edge_inds_s.repeat_interleave(num_edges_per_bond[is_self_edge])
n = len(lg_dst_s)
lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s

def compute_theta_and_phi(edges: dgl.udf.EdgeBatch):
"""Calculate bond angle Theta and Phi using dgl graphs.
# create line graph edges for bonds that are not self edges in atom graph
shared_src = src_indices.unsqueeze(1) == src_indices
back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all(-images.unsqueeze(1) == images, axis=2)
incoming = incoming_edges & (shared_src | ~back_tracking)

Args:
edges: DGL graph edges
edge_inds_ns = not_self_edge.nonzero().squeeze()
lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze()
lg_dst_ns = edge_inds_ns.repeat_interleave(num_edges_per_bond[not_self_edge])
lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns
lg = dgl.graph((lg_src, lg_dst))

Returns:
cos_theta: torch.Tensor
phi: torch.Tensor
triple_bond_lengths (torch.tensor):
"""
angles = compute_theta(edges, cosine=True)
angles["phi"] = torch.zeros_like(angles["cos_theta"])
return angles
for key in pg.edata:
lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()]

# we need to store the sign of bond vector when a bond is a src node in the line
# graph in order to appropriately calculate angles when self edges are involved
lg.ndata["src_bond_sign"] = torch.ones(
(lg.number_of_nodes(), 1), dtype=lg.ndata["bond_vec"].dtype, device=lg.device
)
# if we flip self edges then we need to correct computed angles by pi - angle
# lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s]
# find the intersection for the rare cases where not all edges end up as nodes in the line graph
all_ns, counts = torch.cat([torch.arange(lg.number_of_nodes(), device=graph.device), edge_inds_ns]).unique(
return_counts=True
)
lg_inds_ns = all_ns[torch.where(counts > 1)]
lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][lg_inds_ns]

def compute_theta(edges: dgl.udf.EdgeBatch, cosine: bool = False) -> dict[str, torch.Tensor]:
"""User defined dgl function to calculate bond angles from edges in a graph.
return lg

Args:
edges: DGL graph edges
cosine: Whether to return the cosine of the angle or the angle itself

Returns:
dict[str, torch.Tensor]: Dictionary containing bond angles and distances
"""
vec1 = edges.src["bond_vec"]
vec2 = edges.dst["bond_vec"]
key = "cos_theta" if cosine else "theta"
val = torch.sum(vec1 * vec2, dim=1) / (torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1))
if not cosine:
val = torch.acos(val)
return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]}
def _ensure_3body_line_graph_compatibility(graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float):
"""Ensure that 3body line graph is compatible with a given graph.

Sets edge data in line graph to be consistent with graph. The line graph is updated in place.

def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float):
Args:
graph: atomistic graph
line_graph: line graph of atomistic graph
threebody_cutoff: cutoff for three-body interactions
"""
Calculate the three body indices from pair atom indices.
valid_three_body = graph.edata["bond_dist"] <= threebody_cutoff
if line_graph.num_nodes() == graph.edata["bond_vec"][valid_three_body].shape[0]:
line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][valid_three_body]
line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][valid_three_body]
line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][valid_three_body]
else:
three_body_id = torch.concatenate(line_graph.edges())
max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][:max_three_body_id]
line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][:max_three_body_id]
line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][:max_three_body_id]

Args:
g: DGL graph
threebody_cutoff (float): cutoff for three-body interactions
return line_graph

Returns:
l_g: DGL graph containing three body information from graph

def _ensure_directed_line_graph_compatibility(
graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7
) -> dgl.DGLGraph:
"""Ensure that line graph is compatible with graph.

Sets edge data in line graph to be consistent with graph. The line graph is updated in place.

Args:
graph: atomistic graph
line_graph: line graph of atomistic graph
threebody_cutoff: cutoff for three-body interactions
tol: numerical tolerance for cutoff
"""
valid_three_body = g.edata["bond_dist"] <= threebody_cutoff
src_id_with_three_body = g.edges()[0][valid_three_body]
dst_id_with_three_body = g.edges()[1][valid_three_body]
graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body))
graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body]
graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body]
graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body]
l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body)
return l_g
valid_edges = graph.edata["bond_dist"] <= threebody_cutoff

# this means there probably is a bond that is just at the cutoff
# this should only really occur when batching graphs
if line_graph.number_of_nodes() > sum(valid_edges):
valid_edges = graph.edata["bond_dist"] <= threebody_cutoff + tol

# check again and raise if invalid
if line_graph.number_of_nodes() > sum(valid_edges):
raise RuntimeError("Line graph is not compatible with graph.")

edge_ids = valid_edges.nonzero().squeeze()[: line_graph.number_of_nodes()]
line_graph.ndata["edge_ids"] = edge_ids

for key in graph.edata:
line_graph.ndata[key] = graph.edata[key][edge_ids]

src_indices, dst_indices = graph.edges()
ns_edge_ids = (src_indices[edge_ids] != dst_indices[edge_ids]).nonzero().squeeze()
line_graph.ndata["src_bond_sign"] = torch.ones(
(line_graph.number_of_nodes(), 1), dtype=graph.edata["bond_vec"].dtype, device=line_graph.device
)
line_graph.ndata["src_bond_sign"][ns_edge_ids] = -line_graph.ndata["src_bond_sign"][ns_edge_ids]

return line_graph
13 changes: 2 additions & 11 deletions matgl/models/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
compute_pair_vector_and_distance,
compute_theta_and_phi,
create_line_graph,
ensure_line_graph_compatibility,
)
from matgl.layers import (
MLP,
Expand Down Expand Up @@ -232,17 +233,7 @@ def forward(
if l_g is None:
l_g = create_line_graph(g, self.threebody_cutoff)
else:
valid_three_body = g.edata["bond_dist"] <= self.threebody_cutoff
if l_g.num_nodes() == g.edata["bond_vec"][valid_three_body].shape[0]:
l_g.ndata["bond_vec"] = g.edata["bond_vec"][valid_three_body]
l_g.ndata["bond_dist"] = g.edata["bond_dist"][valid_three_body]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body]
else:
three_body_id = torch.concatenate(l_g.edges())
max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id]
l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id]
l_g = ensure_line_graph_compatibility(g, l_g, self.threebody_cutoff)
l_g.apply_edges(compute_theta_and_phi)
g.edata["rbf"] = expanded_dists
three_body_basis = self.basis_expansion(l_g)
Expand Down
Loading