Skip to content

Commit

Permalink
add E3 features (node/edge) to hamiltonian/density blocks (#129)
Browse files Browse the repository at this point in the history
* stack changes

* fix test ham to feature
  • Loading branch information
floatingCatty authored Apr 15, 2024
1 parent 40a408c commit 48d89f8
Show file tree
Hide file tree
Showing 10 changed files with 347 additions and 23 deletions.
4 changes: 2 additions & 2 deletions dptb/data/dataset/_abacus_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..transforms import TypeMapper, OrbitalMapper
from ._base_datasets import AtomicDataset, AtomicInMemoryDataset
#from dptb.nn.hamiltonian import E3Hamiltonian
from dptb.data.interfaces.ham_to_feature import ham_block_to_feature
from dptb.data.interfaces.ham_to_feature import block_to_feature

orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"}

Expand All @@ -32,7 +32,7 @@ def _abacus_h5_reader(h5file_path, AtomicData_options):
basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)]
idp = OrbitalMapper(basis)
# e3 = E3Hamiltonian(idp=idp, decompose=True)
ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
# with torch.no_grad():
# atomic_data = e3(atomic_data.to_dict())
# atomic_data = AtomicData.from_dict(atomic_data)
Expand Down
4 changes: 2 additions & 2 deletions dptb/data/dataset/_abacus_dataset_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..transforms import TypeMapper, OrbitalMapper
from ._base_datasets import AtomicInMemoryDataset
from dptb.nn.hamiltonian import E3Hamiltonian
from dptb.data.interfaces.ham_to_feature import ham_block_to_feature
from dptb.data.interfaces.ham_to_feature import block_to_feature
from dptb.data.interfaces.abacus import recursive_parse

orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"}
Expand All @@ -32,7 +32,7 @@ def _abacus_h5_reader(h5file_path, AtomicData_options):
basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)]
idp = OrbitalMapper(basis)
# e3 = E3Hamiltonian(idp=idp, decompose=True)
ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
# with torch.no_grad():
# atomic_data = e3(atomic_data.to_dict())
# atomic_data = AtomicData.from_dict(atomic_data)
Expand Down
4 changes: 2 additions & 2 deletions dptb/data/dataset/_default_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..transforms import TypeMapper, OrbitalMapper
from ._base_datasets import AtomicDataset, AtomicInMemoryDataset
#from dptb.nn.hamiltonian import E3Hamiltonian
from dptb.data.interfaces.ham_to_feature import ham_block_to_feature
from dptb.data.interfaces.ham_to_feature import block_to_feature
from dptb.utils.tools import j_loader
from dptb.data.AtomicDataDict import with_edge_vectors
from dptb.nn.hamiltonian import E3Hamiltonian
Expand Down Expand Up @@ -201,7 +201,7 @@ def toAtomicDataList(self, idp: TypeMapper = None):
overlaps = False
# e3 = E3Hamiltonian(idp=idp, decompose=True)
if features != False or overlaps != False:
ham_block_to_feature(atomic_data, idp, features, overlaps)
block_to_feature(atomic_data, idp, features, overlaps)

if not hasattr(atomic_data, AtomicDataDict.EDGE_FEATURES_KEY):
# TODO: initialize the edge and node features temporarily; there should be a better way.
Expand Down
139 changes: 125 additions & 14 deletions dptb/data/interfaces/ham_to_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import h5py
import logging
from dptb.utils.constants import anglrMId, OPENMX2DeePTB
from dptb.data import AtomicData, AtomicDataDict

log = logging.getLogger(__name__)

def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=False):
def block_to_feature(data, idp, blocks=False, overlap_blocks=False):
# Hamiltonian_blocks should be a h5 group in the current version
assert Hamiltonian_blocks != False or overlap_blocks!=False, "Both Hamiltonian and overlap blocks are not provided."
assert blocks != False or overlap_blocks!=False, "Both feature block and overlap blocks are not provided."

if Hamiltonian_blocks:
if blocks:
onsite_ham = []
edge_ham = []
if overlap_blocks:
Expand All @@ -23,17 +24,25 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal
idp.get_orbital_maps()
idp.get_orbpair_maps()

if isinstance(data, AtomicData):
if not hasattr(data, _keys.ATOMIC_NUMBERS_KEY):
setattr(data, _keys.ATOMIC_NUMBERS_KEY, idp.untransform(data[_keys.ATOM_TYPE_KEY]))
if isinstance(data, dict):
if not data.get(_keys.ATOMIC_NUMBERS_KEY, None):
data[_keys.ATOMIC_NUMBERS_KEY] = idp.untransform(data[_keys.ATOM_TYPE_KEY])
atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY]

# onsite features
if Hamiltonian_blocks:
if blocks:
for atom in range(len(atomic_numbers)):
block_index = '_'.join(map(str, map(int, [atom+1, atom+1] + list([0, 0, 0]))))
try:
block = Hamiltonian_blocks[block_index]
block = blocks[block_index]
except:
raise IndexError("Hamiltonian block for onsite not found, check Hamiltonian file.")

if isinstance(block, torch.Tensor):
block = block.cpu().detach().numpy()
symbol = ase.data.chemical_symbols[atomic_numbers[atom]]
basis_list = idp.basis[symbol]
onsite_out = np.zeros(idp.reduced_matrix_element)
Expand All @@ -60,6 +69,7 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal

for atom_i, atom_j, R_shift in zip(edge_index[0], edge_index[1], edge_cell_shift):
block_index = '_'.join(map(str, map(int, [atom_i+1, atom_j+1] + list(R_shift))))
r_index = '_'.join(map(str, map(int, [atom_j+1, atom_i+1] + list(-R_shift))))
symbol_i = ase.data.chemical_symbols[atomic_numbers[atom_i]]
symbol_j = ase.data.chemical_symbols[atomic_numbers[atom_j]]

Expand All @@ -69,24 +79,37 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal
# block_s = overlap_blocks[block_index]
# except:
# raise IndexError("Hamiltonian block for hopping not found, r_cut may be too big for input R.")
if Hamiltonian_blocks:
block = Hamiltonian_blocks.get(block_index, 0)
if block == 0:
if blocks:
block = blocks.get(block_index, None)
if block is None:
block = blocks.get(r_index, None)
if block is not None:
block = block.T
if block is None:
block = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j])
log.warning("Hamiltonian block for hopping {} not found, r_cut may be too big for input R.".format(block_index))

assert block.shape == (idp.norbs[symbol_i], idp.norbs[symbol_j])
if isinstance(block, torch.Tensor):
block = block.cpu().detach().numpy()
if overlap_blocks:
block_s = overlap_blocks.get(block_index, 0)
if block_s == 0:
block_s = overlap_blocks.get(block_index, None)
if block_s is None:
block_s = overlap_blocks.get(r_index, None)
if block_s is not None:
block_s = block_s.T
if block_s is None:
block_s = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j])
log.warning("Overlap block for hopping {} not found, r_cut may be too big for input R.".format(block_index))

assert block_s.shape == (idp.norbs[symbol_i], idp.norbs[symbol_j])

if isinstance(block_s, torch.Tensor):
block_s = block_s.cpu().detach().numpy()

basis_i_list = idp.basis[symbol_i]
basis_j_list = idp.basis[symbol_j]
if Hamiltonian_blocks:
if blocks:
hopping_out = np.zeros(idp.reduced_matrix_element)
if overlap_blocks:
overlap_out = np.zeros(idp.reduced_matrix_element)
Expand All @@ -100,24 +123,112 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal
if idp.full_basis.index(full_basis_i) <= idp.full_basis.index(full_basis_j):
pair_ij = full_basis_i + "-" + full_basis_j
feature_slice = idp.orbpair_maps[pair_ij]
if Hamiltonian_blocks:
if blocks:
block_ij = block[slice_i, slice_j]
hopping_out[feature_slice] = block_ij.flatten()
if overlap_blocks:
block_s_ij = block_s[slice_i, slice_j]
overlap_out[feature_slice] = block_s_ij.flatten()

if Hamiltonian_blocks:
if blocks:
edge_ham.append(hopping_out)
if overlap_blocks:
edge_overlap.append(overlap_out)

if Hamiltonian_blocks:
if blocks:
data[_keys.NODE_FEATURES_KEY] = torch.as_tensor(np.array(onsite_ham), dtype=torch.get_default_dtype())
data[_keys.EDGE_FEATURES_KEY] = torch.as_tensor(np.array(edge_ham), dtype=torch.get_default_dtype())
if overlap_blocks:
data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype())

def _block_key(atom_i, atom_j, shift):
    """Return the ``"i_j_Rx_Ry_Rz"`` key used to index matrix blocks.

    Atom indices are written 1-based; every entry is cast to ``int`` before
    being joined with underscores.
    """
    return '_'.join(map(str, map(int, [atom_i + 1, atom_j + 1] + list(shift))))


def feature_to_block(data, idp):
    """Rebuild per-atom-pair matrix blocks from node and edge feature vectors.

    Parameters
    ----------
    data
        Mapping read through the ``_keys`` constants: node features, edge
        features, atom types, edge index and edge cell shifts.
    idp
        Orbital-mapper object. This function reads ``basis``, ``norbs``,
        ``full_basis``, ``basis_to_full_basis``, ``full_basis_to_basis``,
        ``orbital_maps`` and ``orbpair_maps`` from it and calls
        ``untransform``, ``get_orbital_maps`` and ``get_orbpair_maps``.

    Returns
    -------
    dict
        Maps ``"i_j_Rx_Ry_Rz"`` keys (see ``_block_key``) to tensors of shape
        ``(idp.norbs[symbol_i], idp.norbs[symbol_j])``. An empty dict is
        returned when ``data`` carries no node features (previously this path
        reached ``return blocks`` with ``blocks`` never bound).
    """
    # Always bind the result so the "no features" path returns cleanly.
    blocks = {}

    # The loops below read idp.orbital_maps / idp.orbpair_maps; these calls
    # presumably populate them (one call each is enough -- TODO confirm they
    # are idempotent, the original invoked each of them twice).
    idp.get_orbital_maps()
    idp.get_orbpair_maps()

    if data.get(_keys.NODE_FEATURES_KEY, None) is None:
        return blocks

    node_features = data[_keys.NODE_FEATURES_KEY]
    edge_features = data[_keys.EDGE_FEATURES_KEY]

    def _symbol(atom):
        # Atom type index -> atomic number (via idp.untransform) -> chemical symbol.
        return ase.data.chemical_symbols[idp.untransform(data[_keys.ATOM_TYPE_KEY][atom].reshape(-1))]

    def _angular_momentum(basis_name):
        # The alphabetic part of an orbital name (e.g. "s", "p") selects l.
        return anglrMId[re.findall(r"[a-zA-Z]+", basis_name)[0]]

    # ---- onsite blocks from node features -------------------------------
    for atom, onsite in enumerate(node_features):
        symbol = _symbol(atom)
        basis_list = idp.basis[symbol]
        block = torch.zeros(
            (idp.norbs[symbol], idp.norbs[symbol]),
            device=node_features.device,
            dtype=node_features.dtype,
        )

        for index, basis_i in enumerate(basis_list):
            f_basis_i = idp.basis_to_full_basis[symbol].get(basis_i)
            slice_i = idp.orbital_maps[symbol][basis_i]
            li = _angular_momentum(basis_i)
            # Only pairs with j >= i are read from the feature vector; the
            # other triangle is filled by transposition below.
            for basis_j in basis_list[index:]:
                f_basis_j = idp.basis_to_full_basis[symbol].get(basis_j)
                lj = _angular_momentum(basis_j)
                slice_j = idp.orbital_maps[symbol][basis_j]
                pair_ij = f_basis_i + "-" + f_basis_j
                feature_slice = idp.orbpair_maps[pair_ij]
                block_ij = onsite[feature_slice].reshape(2*li+1, 2*lj+1)
                block[slice_i, slice_j] = block_ij
                if slice_i != slice_j:
                    block[slice_j, slice_i] = block_ij.T

        blocks[_block_key(atom, atom, [0, 0, 0])] = block

    # ---- hopping blocks from edge features ------------------------------
    edge_index = data[_keys.EDGE_INDEX_KEY]
    edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY]
    for edge, hopping in enumerate(edge_features):
        atom_i, atom_j, R_shift = edge_index[0][edge], edge_index[1][edge], edge_cell_shift[edge]
        symbol_i = _symbol(atom_i)
        symbol_j = _symbol(atom_j)
        block = torch.zeros(
            (idp.norbs[symbol_i], idp.norbs[symbol_j]),
            device=edge_features.device,
            dtype=edge_features.dtype,
        )

        for index, f_basis_i in enumerate(idp.full_basis):
            basis_i = idp.full_basis_to_basis[symbol_i].get(f_basis_i)
            if basis_i is None:
                # This full-basis orbital does not exist on atom i's species.
                continue
            li = _angular_momentum(basis_i)
            slice_i = idp.orbital_maps[symbol_i][basis_i]
            for f_basis_j in idp.full_basis[index:]:
                basis_j = idp.full_basis_to_basis[symbol_j].get(f_basis_j)
                if basis_j is None:
                    continue
                lj = _angular_momentum(basis_j)
                slice_j = idp.orbital_maps[symbol_j][basis_j]
                pair_ij = f_basis_i + "-" + f_basis_j
                feature_slice = idp.orbpair_maps[pair_ij]
                block_ij = hopping[feature_slice].reshape(2*li+1, 2*lj+1)
                if f_basis_i == f_basis_j:
                    # Same-orbital pairs are halved -- presumably because the
                    # reverse edge adds the same sub-block again when the two
                    # contributions are accumulated below (TODO confirm).
                    block[slice_i, slice_j] = 0.5 * block_ij
                else:
                    block[slice_i, slice_j] = block_ij

        # Fold the edge and its reverse (j, i, -R) into a single stored block.
        block_index = _block_key(atom_i, atom_j, R_shift)
        if atom_i < atom_j:
            if block_index not in blocks:
                blocks[block_index] = block
            else:
                blocks[block_index] += block
        elif atom_i == atom_j:
            r_index = _block_key(atom_i, atom_j, -R_shift)
            if r_index not in blocks:
                blocks[block_index] = block
            else:
                blocks[r_index] += block.T
        else:
            # Store under the (j, i, -R) key so every pair has one canonical entry.
            block_index = _block_key(atom_j, atom_i, -R_shift)
            if block_index not in blocks:
                blocks[block_index] = block.T
            else:
                blocks[block_index] += block.T

    return blocks


def openmx_to_deeptb(data, idp, openmx_hpath):
# Hamiltonian_blocks should be a h5 group in the current version
Expand Down
2 changes: 1 addition & 1 deletion dptb/nn/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def build_model(
log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
else:
deep_dict_difference(k, v, model_options)

model.to(model.device)
return model


Expand Down
2 changes: 1 addition & 1 deletion dptb/postprocess/bandstructure/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict,
structase = data
data = AtomicData.from_ase(structase, **AtomicData_options)
elif isinstance(data, AtomicData):
structase = data.to_ase()
structase = data.to("cpu").to_ase()
data = data


Expand Down
48 changes: 48 additions & 0 deletions dptb/postprocess/write_ham.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
from dptb.utils.tools import j_must_have
from ase.io import read
import ase
from typing import Union
import matplotlib.pyplot as plt
import torch
from typing import Optional
import matplotlib
import logging
from dptb.data import AtomicData, AtomicDataDict
from dptb.data.interfaces.ham_to_feature import feature_to_block

log = logging.getLogger(__name__)

def write_ham(
        data: Union[AtomicData, ase.Atoms, str],
        model: torch.nn.Module,
        AtomicData_options: Optional[dict] = None,
        device: Union[str, torch.device] = None
        ):
    """Run ``model`` on a structure and return its output as matrix blocks.

    Parameters
    ----------
    data
        An ``AtomicData`` object, an ``ase.Atoms`` object, or a path string
        that is read with ``ase.io.read``.
    model
        Module that is put in eval mode, called on the data dict, and whose
        ``idp`` attribute is applied to the data and passed to
        ``feature_to_block``.
    AtomicData_options
        Keyword arguments forwarded to ``AtomicData.from_ase`` when ``data``
        is a path or an ``ase.Atoms``. ``None`` (the default) means no extra
        options.
    device
        Device the data is moved to before the model is called; a string is
        converted with ``torch.device``.

    Returns
    -------
    The dictionary produced by ``feature_to_block`` for the model output.
    """
    # A None sentinel replaces the former shared mutable ``{}`` default.
    if AtomicData_options is None:
        AtomicData_options = {}

    model.eval()
    if isinstance(device, str):
        device = torch.device(device)

    # Normalise the input to an AtomicData object.
    if isinstance(data, str):
        structase = read(data)
        data = AtomicData.from_ase(structase, **AtomicData_options)
    elif isinstance(data, ase.Atoms):
        structase = data
        data = AtomicData.from_ase(structase, **AtomicData_options)
    elif isinstance(data, AtomicData):
        # NOTE(review): ``structase`` is not used afterwards in this branch.
        # The call is kept because ``.to("cpu")`` may move ``data`` in place
        # -- confirm before removing.
        structase = data.to("cpu").to_ase()

    data = AtomicData.to_AtomicDataDict(data.to(device))
    data = model.idp(data)

    # Forward pass, then convert the predicted features into blocks.
    data = model(data)
    block = feature_to_block(data=data, idp=model.idp)

    return block




Binary file modified dptb/tests/data/silicon_1nn/nnsk.ep500.pth
Binary file not shown.
Loading

0 comments on commit 48d89f8

Please sign in to comment.