Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Sep 29, 2023
1 parent 4dc6038 commit 0f97de3
Showing 1 changed file with 85 additions and 9 deletions.
94 changes: 85 additions & 9 deletions tests/graph/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
from functools import partial

import numpy as np
import pytest
import torch
import torch.testing as tt
from pymatgen.core import Lattice, Structure

from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.compute import (
compute_pair_vector_and_distance,
compute_theta,
compute_theta_and_phi,
create_directed_line_graph,
create_line_graph,
ensure_directed_line_graph_compatibility,
prune_edges_by_features,
)


Expand Down Expand Up @@ -46,8 +51,8 @@ def _calculate_cos_loop(graph, threebody_cutoff=4.0):
for j in range(n_site):
if i == j:
continue
vi = graph.edata["bond_vec"][i + start_index].numpy()
vj = graph.edata["bond_vec"][j + start_index].numpy()
vi = graph.edata["bond_vec"][i + start_index].detach().numpy()
vj = graph.edata["bond_vec"][j + start_index].detach().numpy()
di = np.linalg.norm(vi)
dj = np.linalg.norm(vj)
if (di <= threebody_cutoff) and (dj <= threebody_cutoff):
Expand Down Expand Up @@ -114,14 +119,13 @@ def test_compute_angle(self, graph_Mo, graph_CH4):
)

# test only compute theta
line_graph.apply_edges(compute_theta)
np.testing.assert_array_almost_equal(
np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata["theta"])), decimal=4
)
line_graph.apply_edges(partial(compute_theta, directed=False))
theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))
np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata["theta"])), decimal=4)

# test only compute theta with cosine
_ = line_graph.edata.pop("cos_theta")
line_graph.apply_edges(partial(compute_theta, cosine=True))
line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))
np.testing.assert_array_almost_equal(
np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata["cos_theta"]))
)
Expand All @@ -140,14 +144,14 @@ def test_compute_angle(self, graph_Mo, graph_CH4):
)

# test only compute theta
line_graph.apply_edges(compute_theta)
line_graph.apply_edges(partial(compute_theta, directed=False))
np.testing.assert_array_almost_equal(
np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata["theta"]))
)

# test only compute theta with cosine
_ = line_graph.edata.pop("cos_theta")
line_graph.apply_edges(partial(compute_theta, cosine=True))
line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))
np.testing.assert_array_almost_equal(
np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata["cos_theta"]))
)
Expand Down Expand Up @@ -186,3 +190,75 @@ def test_line_graph_extensive():
assert 2 * g1.number_of_edges() == g2.number_of_edges()
assert 2 * lg1.number_of_nodes() == lg2.number_of_nodes()
assert 2 * lg1.number_of_edges() == lg2.number_of_edges()


@pytest.mark.parametrize("keep_ndata", [True, False])
@pytest.mark.parametrize("keep_edata", [True, False])
def test_remove_edges_by_features(graph_Mo, keep_ndata, keep_edata):
s1, g1, state1 = graph_Mo
bv, bd = compute_pair_vector_and_distance(g1)
g1.edata["bond_vec"] = bv
g1.edata["bond_dist"] = bd

new_cutoff = 3.0
converter = Structure2Graph(element_types=get_element_list([s1]), cutoff=new_cutoff)
g2, state2 = converter.get_graph(s1)

# remove edges by features
new_g = prune_edges_by_features(
g1, "bond_dist", condition=lambda x: x > new_cutoff, keep_ndata=keep_ndata, keep_edata=keep_edata
)
valid_edges = g1.edata["bond_dist"] <= new_cutoff

assert new_g.num_edges() == g2.num_edges()
assert new_g.num_nodes() == g2.num_nodes()
assert torch.allclose(new_g.edata["edge_ids"], valid_edges.nonzero().squeeze())

if keep_ndata:
assert new_g.ndata.keys() == g1.ndata.keys()

if keep_edata:
for key in g1.edata:
if key != "edge_ids":
assert torch.allclose(new_g.edata[key], g1.edata[key][valid_edges])


@pytest.mark.parametrize("cutoff", [2.0, 3.0, 4.0])
@pytest.mark.parametrize("graph_data", ["graph_Mo", "graph_CH4", "graph_MoS", "graph_LiFePO4", "graph_MoSH"])
def test_directed_line_graph(graph_data, cutoff, request):
s1, g1, state1 = request.getfixturevalue(graph_data)
bv, bd = compute_pair_vector_and_distance(g1)
g1.edata["bond_vec"] = bv
g1.edata["bond_dist"] = bd
cos_loop = _calculate_cos_loop(g1, cutoff)
theta_loop = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))

line_graph = create_directed_line_graph(g1, cutoff)
line_graph.apply_edges(compute_theta)

# this test might be lax with just 4 decimal places
np.testing.assert_array_almost_equal(np.sort(theta_loop), np.sort(np.array(line_graph.edata["theta"])), decimal=4)


@pytest.mark.parametrize("graph_data", ["graph_Mo", "graph_CH4", "graph_LiFePO4", "graph_MoSH"])
def test_ensure_directed_line_graph_compat(graph_data, request):
s, g, state = request.getfixturevalue(graph_data)
bv, bd = compute_pair_vector_and_distance(g)
g.edata["bond_vec"] = bv
g.edata["bond_dist"] = bd
line_graph = create_directed_line_graph(g, 3.0)
edge_ids = line_graph.ndata["edge_ids"].clone()
src_bond_sign = line_graph.ndata["src_bond_sign"].clone()
line_graph.ndata["edge_ids"] = torch.zeros(line_graph.num_nodes(), dtype=torch.long)
line_graph.ndata["src_bond_sign"] = torch.zeros(line_graph.num_nodes())

assert not torch.allclose(line_graph.ndata["edge_ids"], edge_ids)
assert not torch.allclose(line_graph.ndata["src_bond_sign"], src_bond_sign)

# test that the line graph is not compatible
line_graph = ensure_directed_line_graph_compatibility(g, line_graph, 3.0)
tt.assert_allclose(line_graph.ndata["edge_ids"], edge_ids)
tt.assert_allclose(line_graph.ndata["src_bond_sign"], src_bond_sign)

with pytest.raises(AssertionError):
ensure_directed_line_graph_compatibility(g, line_graph, 1.0)

0 comments on commit 0f97de3

Please sign in to comment.