Merge changes during the E3 Hamiltonian model development into main (#85)
* update recent

* add write json v1

* update new model

* update model and add support for deeph dataset

* update separate layer norm and extra output in tp

* track temp

* update loss to support onsite shift

* update argcheck for loss options

* update argcheck for loss options

* update loss for onsite shift

* shuffle val loader

* update setinfo's argcheck

* only get diag element mask in e3tb

* update python dependency back to 3.8

* update argcheck for eig loss, remove repeated code in default dataset, and rename tp.py

* remove comment in data/build

* update error raise in test build dataset
floatingCatty authored Mar 19, 2024
1 parent 143274a commit 45e9e87
Showing 37 changed files with 6,447 additions and 864 deletions.
47 changes: 45 additions & 2 deletions dptb/data/AtomicData.py
@@ -22,6 +22,7 @@
 from . import AtomicDataDict
 from .util import _TORCH_INTEGER_DTYPES
 from dptb.utils.torch_geometric.data import Data
+from dptb.utils.constants import atomic_num_dict

 # A type representing ASE-style periodic boundary conditions, which can be partial (the tuple case)
 PBC = Union[bool, Tuple[bool, bool, bool]]
@@ -354,7 +355,7 @@ def __init__(
     def from_points(
         cls,
         pos=None,
-        r_max: float = None,
+        r_max: Union[float, int, dict] = None,
         self_interaction: bool = False,
         cell=None,
         pbc: Optional[PBC] = None,
@@ -882,6 +883,18 @@ def neighbor_list_and_relative_vec(
     if isinstance(pbc, bool):
         pbc = (pbc,) * 3

+    mask_r = False
+    if isinstance(r_max, dict):
+        _r_max = max(r_max.values())
+        if _r_max - min(r_max.values()) > 1e-5:
+            mask_r = True
+
+        if len(r_max) < len(set(atomic_numbers)):
+            raise ValueError("The number of entries in r_max is less than the number of required atom species.")
+    else:
+        _r_max = r_max
+        assert isinstance(r_max, (float, int))
+
     # Either the position or the cell may be on the GPU as tensors
     if isinstance(pos, torch.Tensor):
         temp_pos = pos.detach().cpu().numpy()
@@ -918,7 +931,7 @@ def neighbor_list_and_relative_vec(
         pbc,
         temp_cell,
         temp_pos,
-        cutoff=float(r_max),
+        cutoff=float(_r_max),
         self_interaction=self_interaction,  # we want edges from atom to itself in different periodic images!
         use_scaled_positions=False,
     )
@@ -990,4 +1003,34 @@ def neighbor_list_and_relative_vec(
         (torch.LongTensor(first_idex), torch.LongTensor(second_idex))
     )

+    # TODO: mask the edges that are larger than r_max
+    if mask_r:
+        edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
+        if cell is not None:
+            edge_vec = edge_vec + torch.einsum(
+                "ni,ij->nj",
+                shifts,
+                cell_tensor.reshape(3, 3),  # remove batch dimension
+            )
+
+        edge_length = torch.linalg.norm(edge_vec, dim=-1)
+
+        atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
+        for i in set(atomic_numbers):
+            assert i in atom_species_num
+        r_map = torch.zeros(max(atom_species_num))
+        for k, v in r_max.items():
+            r_map[atomic_num_dict[k]-1] = v
+        edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])
+        r_mask = edge_length <= edge_length_max
+        if any(~r_mask):
+            edge_index = edge_index[:, r_mask]
+            shifts = shifts[r_mask]
+
+        del edge_length
+        del edge_vec
+        del r_map
+        del edge_length_max
+        del r_mask
+
     return edge_index, shifts, cell_tensor
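
Note: the masking added above keeps an edge only while its length stays within the average of the two endpoint species' cutoffs; the neighbor list is first built once with the largest cutoff (cutoff=float(_r_max)), then over-long edges are dropped, which avoids a per-species neighbor search. A minimal standalone sketch of that rule (element symbols, cutoff values, and distances are illustrative, not taken from the commit):

import torch

r_max = {"Si": 5.0, "O": 4.0}      # per-species cutoffs in Angstrom (example values)
atomic_num = {"Si": 14, "O": 8}    # stand-in for dptb.utils.constants.atomic_num_dict

# r_map[z-1] holds the cutoff for atomic number z
r_map = torch.zeros(max(atomic_num.values()))
for k, v in r_max.items():
    r_map[atomic_num[k] - 1] = v

atomic_numbers = torch.tensor([14, 8, 8])          # one Si and two O atoms
edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])  # candidate edges from the max-cutoff search
edge_length = torch.tensor([4.6, 4.4, 3.9])        # example distances

# An Si-O edge survives up to 0.5 * (5.0 + 4.0) = 4.5; an O-O edge up to 4.0.
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]] - 1]
                         + r_map[atomic_numbers[edge_index[1]] - 1])
keep = edge_length <= edge_length_max
edge_index = edge_index[:, keep]  # drops the 4.6 Angstrom Si-O edge here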
30 changes: 21 additions & 9 deletions dptb/data/build.py
@@ -5,6 +5,7 @@
 from importlib import import_module

 from dptb.data.dataset import DefaultDataset
+from dptb.data.dataset._deeph_dataset import DeePHE3Dataset
 from dptb import data
 from dptb.data.transforms import TypeMapper, OrbitalMapper
 from dptb.data import AtomicDataset, register_fields
@@ -133,7 +134,7 @@ def build_dataset(set_options, common_options):
     """
     dataset_type = set_options.get("type", "DefaultDataset")

-    if dataset_type == "DefaultDataset":
+    if dataset_type in ["DefaultDataset", "DeePHDataset"]:
         # See if we can get an OrbitalMapper.
         if "basis" in common_options:
             idp = OrbitalMapper(common_options["basis"])
@@ -189,15 +190,26 @@ def build_dataset(set_options, common_options):
         # The order itself is not important, but must be consistent for the same list.
         info_files = {key: info_files[key] for key in sorted(info_files)}

-        dataset = DefaultDataset(
-            root=root,
-            type_mapper=idp,
-            get_Hamiltonian=set_options.get("get_Hamiltonian", False),
-            get_eigenvalues=set_options.get("get_eigenvalues", False),
-            info_files=info_files
-        )
+        if dataset_type == "DeePHDataset":
+            dataset = DeePHE3Dataset(
+                root=root,
+                type_mapper=idp,
+                get_Hamiltonian=set_options.get("get_Hamiltonian", False),
+                get_eigenvalues=set_options.get("get_eigenvalues", False),
+                info_files=info_files
+            )
+        else:
+            dataset = DefaultDataset(
+                root=root,
+                type_mapper=idp,
+                get_Hamiltonian=set_options.get("get_Hamiltonian", False),
+                get_overlap=set_options.get("get_overlap", False),
+                get_DM=set_options.get("get_DM", False),
+                get_eigenvalues=set_options.get("get_eigenvalues", False),
+                info_files=info_files
+            )

     else:
         raise ValueError(f"Unsupported dataset type: {dataset_type}.")

-    return dataset
+    return dataset
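
Note: with this change, build_dataset dispatches on set_options["type"]. A hedged usage sketch follows; the "root" key, basis contents, and paths are illustrative assumptions, as only "type" and the get_* flags appear in the diff above:

from dptb.data.build import build_dataset

common_options = {"basis": {"Si": ["3s", "3p"]}}  # example basis used to build the OrbitalMapper
set_options = {
    "type": "DeePHDataset",   # selects the DeePHE3Dataset branch added above
    "root": "./deeph_data",   # assumed key pointing at the dataset root
    "get_Hamiltonian": True,
    "get_eigenvalues": False,
}

dataset = build_dataset(set_options, common_options)
# Types other than "DefaultDataset"/"DeePHDataset" hit the ValueError at the end of build_dataset.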