Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge changes during the E3 Hamiltonian model development into main #85

Merged
merged 18 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions dptb/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import AtomicDataDict
from .util import _TORCH_INTEGER_DTYPES
from dptb.utils.torch_geometric.data import Data
from dptb.utils.constants import atomic_num_dict

# A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]
Expand Down Expand Up @@ -354,7 +355,7 @@ def __init__(
def from_points(
cls,
pos=None,
r_max: float = None,
r_max: Union[float, int, dict] = None,
self_interaction: bool = False,
cell=None,
pbc: Optional[PBC] = None,
Expand Down Expand Up @@ -882,6 +883,18 @@ def neighbor_list_and_relative_vec(
if isinstance(pbc, bool):
pbc = (pbc,) * 3

mask_r = False
if isinstance(r_max, dict):
_r_max = max(r_max.values())
if _r_max - min(r_max.values()) > 1e-5:
mask_r = True

if len(r_max) < len(set(atomic_numbers)):
raise ValueError("The number of r_max is less than the number of required atom species.")
else:
_r_max = r_max
assert isinstance(r_max, (float, int))

# Either the position or the cell may be on the GPU as tensors
if isinstance(pos, torch.Tensor):
temp_pos = pos.detach().cpu().numpy()
Expand Down Expand Up @@ -918,7 +931,7 @@ def neighbor_list_and_relative_vec(
pbc,
temp_cell,
temp_pos,
cutoff=float(r_max),
cutoff=float(_r_max),
self_interaction=self_interaction, # we want edges from atom to itself in different periodic images!
use_scaled_positions=False,
)
Expand Down Expand Up @@ -990,4 +1003,34 @@ def neighbor_list_and_relative_vec(
(torch.LongTensor(first_idex), torch.LongTensor(second_idex))
)

# TODO: mask the edges that is larger than r_max
if mask_r:
edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
if cell is not None:
edge_vec = edge_vec + torch.einsum(
"ni,ij->nj",
shifts,
cell_tensor.reshape(3,3), # remove batch dimension
)

edge_length = torch.linalg.norm(edge_vec, dim=-1)

atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
for i in set(atomic_numbers):
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])
r_mask = edge_length <= edge_length_max
if any(~r_mask):
edge_index = edge_index[:, r_mask]
shifts = shifts[r_mask]

del edge_length
del edge_vec
del r_map
del edge_length_max
del r_mask

return edge_index, shifts, cell_tensor
30 changes: 21 additions & 9 deletions dptb/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from importlib import import_module

from dptb.data.dataset import DefaultDataset
from dptb.data.dataset._deeph_dataset import DeePHE3Dataset
from dptb import data
from dptb.data.transforms import TypeMapper, OrbitalMapper
from dptb.data import AtomicDataset, register_fields
Expand Down Expand Up @@ -133,7 +134,7 @@ def build_dataset(set_options, common_options):
"""
dataset_type = set_options.get("type", "DefaultDataset")

if dataset_type == "DefaultDataset":
if dataset_type in ["DefaultDataset", "DeePHDataset"]:
# See if we can get a OrbitalMapper.
if "basis" in common_options:
idp = OrbitalMapper(common_options["basis"])
Expand Down Expand Up @@ -189,15 +190,26 @@ def build_dataset(set_options, common_options):
# The order itself is not important, but must be consistant for the same list.
info_files = {key: info_files[key] for key in sorted(info_files)}

dataset = DefaultDataset(
root=root,
type_mapper=idp,
get_Hamiltonian=set_options.get("get_Hamiltonian", False),
get_eigenvalues=set_options.get("get_eigenvalues", False),
info_files = info_files
)
if dataset_type == "DeePHDataset":
dataset = DeePHE3Dataset(
root=root,
type_mapper=idp,
get_Hamiltonian=set_options.get("get_Hamiltonian", False),
get_eigenvalues=set_options.get("get_eigenvalues", False),
info_files = info_files
)
else:
dataset = DefaultDataset(
root=root,
type_mapper=idp,
get_Hamiltonian=set_options.get("get_Hamiltonian", False),
get_overlap=set_options.get("get_overlap", False),
get_DM=set_options.get("get_DM", False),
get_eigenvalues=set_options.get("get_eigenvalues", False),
info_files = info_files
)

else:
raise ValueError(f"Not support dataset type: {type}.")

return dataset
return dataset
Loading
Loading