Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set pbc_offsift and pos as float64 #153

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions matgl/graph/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ def get_graph_from_processed_structure(
"""
u, v = torch.tensor(src_id), torch.tensor(dst_id)
g = dgl.graph((u, v), num_nodes=len(structure))
pbc_offset = torch.tensor(images)
pbc_offset = torch.tensor(images, dtype=torch.float64)
g.edata["pbc_offset"] = pbc_offset.to(matgl.int_th)
g.edata["pbc_offshift"] = torch.matmul(pbc_offset, torch.tensor(lattice_matrix[0])).to(matgl.float_th)
# Note: pbc_ offshift and pos needs to be float64 to handle cases where bonds are exactly at cutoff
g.edata["pbc_offshift"] = torch.matmul(pbc_offset, torch.tensor(lattice_matrix[0]))
g.edata["lattice"] = torch.tensor(np.repeat(lattice_matrix, g.num_edges(), axis=0), dtype=matgl.float_th)
element_to_index = {elem: idx for idx, elem in enumerate(element_types)}
node_type = (
Expand All @@ -62,6 +63,6 @@ def get_graph_from_processed_structure(
else np.array([element_to_index[elem] for elem in structure.get_chemical_symbols()])
)
g.ndata["node_type"] = torch.tensor(node_type, dtype=matgl.int_th)
g.ndata["pos"] = torch.tensor(cart_coords, dtype=matgl.float_th)
g.ndata["pos"] = torch.tensor(cart_coords, dtype=torch.float64)
state_attr = np.array([0.0, 0.0]).astype(matgl.float_np)
return g, state_attr
28 changes: 28 additions & 0 deletions tests/graph/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import numpy as np
import torch
from pymatgen.core import Lattice, Structure

from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.compute import (
compute_pair_vector_and_distance,
compute_theta,
Expand Down Expand Up @@ -158,3 +160,29 @@ def test_compute_three_body(self, graph_AcAla3NHMe):
line_graph = create_line_graph(g1, 5.0)
line_graph.apply_edges(compute_theta_and_phi)
np.testing.assert_allclose(line_graph.edata["triple_bond_lengths"].detach().numpy()[0], 1.777829)


def test_line_graph_extensive():
structure = Structure.from_spacegroup("Fm-3m", Lattice.cubic(6.0 / np.sqrt(2)), ["Fe"], [[0, 0, 0]])

element_types = get_element_list([structure])
converter = Structure2Graph(element_types=element_types, cutoff=5.0)
g1, _ = converter.get_graph(structure)
bond_vec, bond_dist = compute_pair_vector_and_distance(g1)
g1.edata["bond_dist"] = bond_dist
g1.edata["bond_vec"] = bond_vec

supercell = structure.copy()
supercell.make_supercell([2, 1, 1])
g2, _ = converter.get_graph(supercell)
bond_vec, bond_dist = compute_pair_vector_and_distance(g2)
g2.edata["bond_dist"] = bond_dist
g2.edata["bond_vec"] = bond_vec

lg1 = create_line_graph(g1, 3.0)
lg2 = create_line_graph(g2, 3.0)

assert 2 * g1.number_of_nodes() == g2.number_of_nodes()
assert 2 * g1.number_of_edges() == g2.number_of_edges()
assert 2 * lg1.number_of_nodes() == lg2.number_of_nodes()
assert 2 * lg1.number_of_edges() == lg2.number_of_edges()