From eaac837fbd86c190d86d123ba77ac4bcd9883f3a Mon Sep 17 00:00:00 2001 From: zhanghao Date: Fri, 12 Apr 2024 14:44:22 +0800 Subject: [PATCH 1/2] stack changes --- dptb/data/dataset/_abacus_dataset.py | 4 +- dptb/data/dataset/_abacus_dataset_mem.py | 4 +- dptb/data/dataset/_default_dataset.py | 4 +- dptb/data/interfaces/ham_to_feature.py | 130 ++++++++++++++++++++--- dptb/nn/build.py | 2 +- dptb/postprocess/bandstructure/band.py | 2 +- dptb/postprocess/write_ham.py | 48 +++++++++ dptb/utils/constants.py | 2 +- 8 files changed, 173 insertions(+), 23 deletions(-) create mode 100644 dptb/postprocess/write_ham.py diff --git a/dptb/data/dataset/_abacus_dataset.py b/dptb/data/dataset/_abacus_dataset.py index bcdb1c62..19911859 100644 --- a/dptb/data/dataset/_abacus_dataset.py +++ b/dptb/data/dataset/_abacus_dataset.py @@ -14,7 +14,7 @@ from ..transforms import TypeMapper, OrbitalMapper from ._base_datasets import AtomicDataset, AtomicInMemoryDataset #from dptb.nn.hamiltonian import E3Hamiltonian -from dptb.data.interfaces.ham_to_feature import ham_block_to_feature +from dptb.data.interfaces.ham_to_feature import block_to_feature orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"} @@ -32,7 +32,7 @@ def _abacus_h5_reader(h5file_path, AtomicData_options): basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)] idp = OrbitalMapper(basis) # e3 = E3Hamiltonian(idp=idp, decompose=True) - ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) + block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) # with torch.no_grad(): # atomic_data = e3(atomic_data.to_dict()) # atomic_data = AtomicData.from_dict(atomic_data) diff --git a/dptb/data/dataset/_abacus_dataset_mem.py b/dptb/data/dataset/_abacus_dataset_mem.py index de4263ae..7957e4a8 100644 --- a/dptb/data/dataset/_abacus_dataset_mem.py +++ b/dptb/data/dataset/_abacus_dataset_mem.py @@ -13,7 +13,7 @@ from ..transforms import TypeMapper, OrbitalMapper from ._base_datasets import AtomicInMemoryDataset from dptb.nn.hamiltonian import E3Hamiltonian -from dptb.data.interfaces.ham_to_feature import ham_block_to_feature +from dptb.data.interfaces.ham_to_feature import block_to_feature from dptb.data.interfaces.abacus import recursive_parse orbitalLId = {0:"s", 1:"p", 2:"d", 3:"f"} @@ -32,7 +32,7 @@ def _abacus_h5_reader(h5file_path, AtomicData_options): basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)] idp = OrbitalMapper(basis) # e3 = E3Hamiltonian(idp=idp, decompose=True) - ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) + block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False)) # with torch.no_grad(): # atomic_data = e3(atomic_data.to_dict()) # atomic_data = AtomicData.from_dict(atomic_data) diff --git a/dptb/data/dataset/_default_dataset.py b/dptb/data/dataset/_default_dataset.py index 88180962..52ca442f 100644 --- a/dptb/data/dataset/_default_dataset.py +++ b/dptb/data/dataset/_default_dataset.py @@ -16,7 +16,7 @@ from ..transforms import TypeMapper, OrbitalMapper from ._base_datasets import AtomicDataset, AtomicInMemoryDataset #from dptb.nn.hamiltonian import E3Hamiltonian -from dptb.data.interfaces.ham_to_feature import ham_block_to_feature +from dptb.data.interfaces.ham_to_feature import block_to_feature from dptb.utils.tools import j_loader from dptb.data.AtomicDataDict import with_edge_vectors from dptb.nn.hamiltonian import E3Hamiltonian @@ -201,7 +201,7 @@ def toAtomicDataList(self, idp: TypeMapper = None): overlaps = False # e3 = E3Hamiltonian(idp=idp, decompose=True) if features != False or overlaps != False: - ham_block_to_feature(atomic_data, idp, features, overlaps) + block_to_feature(atomic_data, idp, features, overlaps) if not hasattr(atomic_data, AtomicDataDict.EDGE_FEATURES_KEY): # TODO: initialize the edge and node feature tempretely, there should be a better way. diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py index cedb2aec..aa2e1ffa 100644 --- a/dptb/data/interfaces/ham_to_feature.py +++ b/dptb/data/interfaces/ham_to_feature.py @@ -10,11 +10,11 @@ log = logging.getLogger(__name__) -def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=False): +def block_to_feature(data, idp, blocks=False, overlap_blocks=False): # Hamiltonian_blocks should be a h5 group in the current version - assert Hamiltonian_blocks != False or overlap_blocks!=False, "Both Hamiltonian and overlap blocks are not provided." + assert blocks != False or overlap_blocks!=False, "Both feature block and overlap blocks are not provided." - if Hamiltonian_blocks: + if blocks: onsite_ham = [] edge_ham = [] if overlap_blocks: @@ -26,14 +26,16 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY] # onsite features - if Hamiltonian_blocks: + if blocks: for atom in range(len(atomic_numbers)): block_index = '_'.join(map(str, map(int, [atom+1, atom+1] + list([0, 0, 0])))) try: - block = Hamiltonian_blocks[block_index] + block = blocks[block_index] except: raise IndexError("Hamiltonian block for onsite not found, check Hamiltonian file.") + if isinstance(block, torch.Tensor): + block = block.cpu().detach().numpy() symbol = ase.data.chemical_symbols[atomic_numbers[atom]] basis_list = idp.basis[symbol] onsite_out = np.zeros(idp.reduced_matrix_element) @@ -60,6 +62,7 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal for atom_i, atom_j, R_shift in zip(edge_index[0], edge_index[1], edge_cell_shift): block_index = '_'.join(map(str, map(int, [atom_i+1, atom_j+1] + list(R_shift)))) + r_index = '_'.join(map(str, map(int, [atom_j+1, atom_i+1] + list(-R_shift)))) symbol_i = ase.data.chemical_symbols[atomic_numbers[atom_i]] symbol_j = ase.data.chemical_symbols[atomic_numbers[atom_j]] @@ -69,24 +72,37 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal # block_s = overlap_blocks[block_index] # except: # raise IndexError("Hamiltonian block for hopping not found, r_cut may be too big for input R.") - if Hamiltonian_blocks: - block = Hamiltonian_blocks.get(block_index, 0) - if block == 0: + if blocks: + block = blocks.get(block_index, None) + if block is None: + block = blocks.get(r_index, None) + if block is not None: + block = block.T + if block is None: block = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j]) log.warning("Hamiltonian block for hopping {} not found, r_cut may be too big for input R.".format(block_index)) assert block.shape == (idp.norbs[symbol_i], idp.norbs[symbol_j]) + if isinstance(block, torch.Tensor): + block = block.cpu().detach().numpy() if overlap_blocks: - block_s = overlap_blocks.get(block_index, 0) - if block_s == 0: + block_s = overlap_blocks.get(block_index, None) + if block_s is None: + block_s = overlap_blocks.get(r_index, None) + if block_s is not None: + block_s = block_s.T + if block_s is None: block_s = torch.zeros(idp.norbs[symbol_i], idp.norbs[symbol_j]) log.warning("Overlap block for hopping {} not found, r_cut may be too big for input R.".format(block_index)) assert block_s.shape == (idp.norbs[symbol_i], idp.norbs[symbol_j]) + + if isinstance(block_s, torch.Tensor): + block_s = block_s.cpu().detach().numpy() basis_i_list = idp.basis[symbol_i] basis_j_list = idp.basis[symbol_j] - if Hamiltonian_blocks: + if blocks: hopping_out = np.zeros(idp.reduced_matrix_element) if overlap_blocks: overlap_out = np.zeros(idp.reduced_matrix_element) @@ -100,24 +116,110 @@ def ham_block_to_feature(data, idp, Hamiltonian_blocks=False, overlap_blocks=Fal if idp.full_basis.index(full_basis_i) <= idp.full_basis.index(full_basis_j): pair_ij = full_basis_i + "-" + full_basis_j feature_slice = idp.orbpair_maps[pair_ij] - if Hamiltonian_blocks: + if blocks: block_ij = block[slice_i, slice_j] hopping_out[feature_slice] = block_ij.flatten() if overlap_blocks: block_s_ij = block_s[slice_i, slice_j] overlap_out[feature_slice] = block_s_ij.flatten() - if Hamiltonian_blocks: + if blocks: edge_ham.append(hopping_out) if overlap_blocks: edge_overlap.append(overlap_out) - if Hamiltonian_blocks: + if blocks: data[_keys.NODE_FEATURES_KEY] = torch.as_tensor(np.array(onsite_ham), dtype=torch.get_default_dtype()) data[_keys.EDGE_FEATURES_KEY] = torch.as_tensor(np.array(edge_ham), dtype=torch.get_default_dtype()) if overlap_blocks: data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype()) +def feature_to_block(data, idp): + idp.get_orbital_maps() + idp.get_orbpair_maps() + + has_block = False + if data.get(_keys.NODE_FEATURES_KEY, None) is not None: + node_features = data[_keys.NODE_FEATURES_KEY] + edge_features = data[_keys.EDGE_FEATURES_KEY] + has_block = True + blocks = {} + + idp.get_orbital_maps() + idp.get_orbpair_maps() + + if has_block: + # get node blocks from node_features + for atom, onsite in enumerate(node_features): + symbol = ase.data.chemical_symbols[idp.untransform(data[_keys.ATOM_TYPE_KEY][atom].reshape(-1))] + basis_list = idp.basis[symbol] + block = torch.zeros((idp.norbs[symbol], idp.norbs[symbol]), device=node_features.device, dtype=node_features.dtype) + + for index, basis_i in enumerate(basis_list): + slice_i = idp.orbital_maps[symbol][basis_i] + li = anglrMId[re.findall(r"[a-zA-Z]+", basis_i)[0]] + for basis_j in basis_list[index:]: + lj = anglrMId[re.findall(r"[a-zA-Z]+", basis_j)[0]] + slice_j = idp.orbital_maps[symbol][basis_j] + pair_ij = basis_i + "-" + basis_j + feature_slice = idp.orbpair_maps[pair_ij] + block_ij = onsite[feature_slice].reshape(2*li+1, 2*lj+1) + block[slice_i, slice_j] = block_ij + if slice_i != slice_j: + block[slice_j, slice_i] = block_ij.T + + block_index = '_'.join(map(str, map(int, [atom+1, atom+1] + list([0, 0, 0])))) + blocks[block_index] = block + + # get edge blocks from edge_features + edge_index = data[_keys.EDGE_INDEX_KEY] + edge_cell_shift = data[_keys.EDGE_CELL_SHIFT_KEY] + for edge, hopping in enumerate(edge_features): + atom_i, atom_j, R_shift = edge_index[0][edge], edge_index[1][edge], edge_cell_shift[edge] + symbol_i = ase.data.chemical_symbols[idp.untransform(data[_keys.ATOM_TYPE_KEY][atom_i].reshape(-1))] + symbol_j = ase.data.chemical_symbols[idp.untransform(data[_keys.ATOM_TYPE_KEY][atom_j].reshape(-1))] + block = torch.zeros((idp.norbs[symbol_i], idp.norbs[symbol_j]), device=edge_features.device, dtype=edge_features.dtype) + + for index, f_basis_i in enumerate(idp.full_basis): + basis_i = idp.full_basis_to_basis[symbol_i].get(f_basis_i) + if basis_i is None: + continue + li = anglrMId[re.findall(r"[a-zA-Z]+", basis_i)[0]] + slice_i = idp.orbital_maps[symbol_i][basis_i] + for f_basis_j in idp.full_basis[index:]: + basis_j = idp.full_basis_to_basis[symbol_j].get(f_basis_j) + if basis_j is None: + continue + lj = anglrMId[re.findall(r"[a-zA-Z]+", basis_j)[0]] + slice_j = idp.orbital_maps[symbol_j][basis_j] + pair_ij = basis_i + "-" + basis_j + feature_slice = idp.orbpair_maps[pair_ij] + block_ij = hopping[feature_slice].reshape(2*li+1, 2*lj+1) + if f_basis_i == f_basis_j: + block[slice_i, slice_j] = 0.5 * block_ij + else: + block[slice_i, slice_j] = block_ij + + block_index = '_'.join(map(str, map(int, [atom_i+1, atom_j+1] + list(R_shift)))) + if atom_i < atom_j: + if blocks.get(block_index, None) is None: + blocks[block_index] = block + else: + blocks[block_index] += block + elif atom_i == atom_j: + r_index = '_'.join(map(str, map(int, [atom_i+1, atom_j+1] + list(-R_shift)))) + if blocks.get(r_index, None) is None: + blocks[block_index] = block + else: + blocks[r_index] += block.T + else: + block_index = '_'.join(map(str, map(int, [atom_j+1, atom_i+1] + list(-R_shift)))) + if blocks.get(block_index, None) is None: + blocks[block_index] = block.T + else: + blocks[block_index] += block.T + return blocks + def openmx_to_deeptb(data, idp, openmx_hpath): # Hamiltonian_blocks should be a h5 group in the current version diff --git a/dptb/nn/build.py b/dptb/nn/build.py index ed1c057c..9de03cda 100644 --- a/dptb/nn/build.py +++ b/dptb/nn/build.py @@ -145,7 +145,7 @@ def build_model( log.warning(f"The model options {k} is not defined in input model_options, set to {v}.") else: deep_dict_difference(k, v, model_options) - + model.to(model.device) return model diff --git a/dptb/postprocess/bandstructure/band.py b/dptb/postprocess/bandstructure/band.py index 79d1d33f..fcf21ec0 100644 --- a/dptb/postprocess/bandstructure/band.py +++ b/dptb/postprocess/bandstructure/band.py @@ -210,7 +210,7 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict, structase = data data = AtomicData.from_ase(structase, **AtomicData_options) elif isinstance(data, AtomicData): - structase = data.to_ase() + structase = data.to("cpu").to_ase() data = data diff --git a/dptb/postprocess/write_ham.py b/dptb/postprocess/write_ham.py new file mode 100644 index 00000000..e250968d --- /dev/null +++ b/dptb/postprocess/write_ham.py @@ -0,0 +1,48 @@ +import numpy as np +from dptb.utils.tools import j_must_have +from ase.io import read +import ase +from typing import Union +import matplotlib.pyplot as plt +import torch +from typing import Optional +import matplotlib +import logging +from dptb.data import AtomicData, AtomicDataDict +from dptb.data.interfaces.ham_to_feature import feature_to_block + +log = logging.getLogger(__name__) + +def write_ham( + data: Union[AtomicData, ase.Atoms, str], + model: torch.nn.Module, + AtomicData_options: dict={}, + device: Union[str, torch.device]=None + ): + + model.eval() + if isinstance(device, str): + device = torch.device(device) + # get the AtomicData structure and the ase structure + if isinstance(data, str): + structase = read(data) + data = AtomicData.from_ase(structase, **AtomicData_options) + elif isinstance(data, ase.Atoms): + structase = data + data = AtomicData.from_ase(structase, **AtomicData_options) + elif isinstance(data, AtomicData): + structase = data.to("cpu").to_ase() + data = data + + data = AtomicData.to_AtomicDataDict(data.to(device)) + data = model.idp(data) + + # set the kpoint of the AtomicData + data = model(data) + block = feature_to_block(data=data, idp=model.idp) + + return block + + + + diff --git a/dptb/utils/constants.py b/dptb/utils/constants.py index 4ebfc9ec..4e3e86f4 100644 --- a/dptb/utils/constants.py +++ b/dptb/utils/constants.py @@ -71,7 +71,7 @@ ABACUS2DeePTB[2][[1, 3]] *= -1 ABACUS2DeePTB[3][[0, 6, 2, 4]] *= -1 ABACUS2DeePTB[4][[1, 7, 3, 5]] *= -1 -ABACUS2DeePTB[5][[0, 8, 2, 6, 4]] *= -1 +ABACUS2DeePTB[5][[0, 10, 8, 2, 6, 4]] *= -1 OPENMX2DeePTB = { "s": torch.eye(1).double(), From 48f020ba594b88022657ed1af83a7ca835abbde2 Mon Sep 17 00:00:00 2001 From: zhanghao Date: Mon, 15 Apr 2024 16:30:58 +0800 Subject: [PATCH 2/2] fix test ham to feature --- dptb/data/interfaces/ham_to_feature.py | 13 +- dptb/tests/data/silicon_1nn/nnsk.ep500.pth | Bin 5419 -> 5419 bytes dptb/tests/test_block_to_feature.py | 165 +++++++++++++++++++++ 3 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 dptb/tests/test_block_to_feature.py diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py index aa2e1ffa..9540f22d 100644 --- a/dptb/data/interfaces/ham_to_feature.py +++ b/dptb/data/interfaces/ham_to_feature.py @@ -7,6 +7,7 @@ import h5py import logging from dptb.utils.constants import anglrMId, OPENMX2DeePTB +from dptb.data import AtomicData, AtomicDataDict log = logging.getLogger(__name__) @@ -23,6 +24,12 @@ def block_to_feature(data, idp, blocks=False, overlap_blocks=False): idp.get_orbital_maps() idp.get_orbpair_maps() + if isinstance(data, AtomicData): + if not hasattr(data, _keys.ATOMIC_NUMBERS_KEY): + setattr(data, _keys.ATOMIC_NUMBERS_KEY, idp.untransform(data[_keys.ATOM_TYPE_KEY])) + if isinstance(data, dict): + if not data.get(_keys.ATOMIC_NUMBERS_KEY, None): + data[_keys.ATOMIC_NUMBERS_KEY] = idp.untransform(data[_keys.ATOM_TYPE_KEY]) atomic_numbers = data[_keys.ATOMIC_NUMBERS_KEY] # onsite features @@ -156,12 +163,14 @@ def feature_to_block(data, idp): block = torch.zeros((idp.norbs[symbol], idp.norbs[symbol]), device=node_features.device, dtype=node_features.dtype) for index, basis_i in enumerate(basis_list): + f_basis_i = idp.basis_to_full_basis[symbol].get(basis_i) slice_i = idp.orbital_maps[symbol][basis_i] li = anglrMId[re.findall(r"[a-zA-Z]+", basis_i)[0]] for basis_j in basis_list[index:]: + f_basis_j = idp.basis_to_full_basis[symbol].get(basis_j) lj = anglrMId[re.findall(r"[a-zA-Z]+", basis_j)[0]] slice_j = idp.orbital_maps[symbol][basis_j] - pair_ij = basis_i + "-" + basis_j + pair_ij = f_basis_i + "-" + f_basis_j feature_slice = idp.orbpair_maps[pair_ij] block_ij = onsite[feature_slice].reshape(2*li+1, 2*lj+1) block[slice_i, slice_j] = block_ij @@ -192,7 +201,7 @@ def feature_to_block(data, idp): continue lj = anglrMId[re.findall(r"[a-zA-Z]+", basis_j)[0]] slice_j = idp.orbital_maps[symbol_j][basis_j] - pair_ij = basis_i + "-" + basis_j + pair_ij = f_basis_i + "-" + f_basis_j feature_slice = idp.orbpair_maps[pair_ij] block_ij = hopping[feature_slice].reshape(2*li+1, 2*lj+1) if f_basis_i == f_basis_j: diff --git a/dptb/tests/data/silicon_1nn/nnsk.ep500.pth b/dptb/tests/data/silicon_1nn/nnsk.ep500.pth index b72a49313258e90c5402f3e89f59466cfd647871..4e5b74d04e07b7eb0fefa961fab9b41854f05259 100644 GIT binary patch delta 2119 zcmY*aS#%Ud6rD-bkWo}51_VST0+T>OAPJjF1c9hcSO%9y9h=H@&2&r8bl|2$S#F>%sg-S5@A@4Z{^-5zckZaFwD zK31=d#bV6O;Gz)Y%hTeG#fA_UC-t+1xsNcLp)te+y+(X)urV=>OOo=VgJF_-@i1Aw zm@Sw>nSPEqY0`c>Lvg7D&v*f*N<5PfdK8!Gb>c|YaS893k7@GD?Wev&JX~JV6K4r8 zqtwLo^0D~jVpE7Kl6t)`7J|G%JvU=vM(|LPd1FmOysccQnY?_i z7L&!2Jj?mSM0+XF&l6aVZ?KF-&A<^!=Y^P`v|426bihN01kRCso}jaQqGn20SBM4D z$bDiLurS0Tt6l_604#s-JpyKtkmP; zOgmPM3_*|9z;WGrEvCh_gjO?z>$L{0MvD*OhLRQK)HqhB&?|dnlASYo*H|K_wZf8^ z#n?QowItZODzNnxu#^JZz;R;=Hw7VXmSh@vak(bp1h>d!dQunl#jr8Nt+E|4f=qCM z+q5{-&lN6<-4K(;I|$WX`Hdy(5I7c1JReO(959i-qV*%2{Ka zddO6F6PpiHG!y5bqM7DqReY8smx86HV!$&YyXNbo}L}6D?_x|yxf8K?aFPl-^WC)P+_GsNS8Px>5+a!IbDX=k%yn+n{P)+bAH-LnA`qV|*u zvJ1BGI)+fDr~g-`NIPZ049}$TY`^17L}YJI_aFb9{H1YdwyaKQG!-Ze z&!zEvd0*Y68T)$wj8_J_&t(+*OLCmWoKnxoxPcgo7s|iZH5Fe}4H^z%JgVfSG+rKI zgaoo`CAhCdf*WbD&8bf!$-t|x(ALX%?K@qw%?%77!IHe?Y#RO9_o)p&o* z@ky$PPbHck6bg|}MvMB{sMy8$T*V%(iao~hMG9Z4y^MZoLY;)ld^Hjo{cq>B<7*ZA zO;zZ(9N(pIJg|0!_&&MPs?gb)!w;4q&=O?*D1py#)tL)%f>=jJZZ-&lwsl(Fn3Ds0 z51kU1daAzsTVm>jS+PQaZEmApS9|;1NGRvFuUg&|E3Z#99p09hUROVIzC1iG6>mBG J@3_|5{{SLDo?`$2 delta 2018 zcmY*aS#%Ud6rE0@lSD)#A!rhiMQ{=z8A!tB5{N?7HVA`CqmE5wre?aOXS#1!bqruL zuHYy}wB5i3+{L({;J)vAgrgsR_~i$m9?$VZ{PkE}&E!n_tNXor_r3e-zSrZOqHa6fw2KVq`Iwf_&Jd1a!c2x{AJerO@wv%FOByp0@}r00{OHBO z1@a}^iB`(AGsVd)9WgT$ZCY&z?|5jJcv-&Pp_r+~g*9i{vkC86idpj09-*#999&q| z>V%h3n#JtkxYAad<6~|@YY@hwSJ0_rXAI06I;Cih!Zb{H#?YC$gW@7pk*7i$1z~V8 z-^YTeSj)x5gZ+p1htC0-Kvt5FuWPcTfR2bT<%FR#jfIK9{=bfgPmrr=Ru%<2YuW}D zt1($CDX_dtvgj&%+7wZQPq}@J}2_qqy(@_UYBygS-3It1o-)q`SJwAG+ zk-Nk!qR+=Nqd@{wGe1HshUJV&DvXR!J?BzjMM(G@kxTf^ED0|va;#JvWer)AITqO- z8CR)Ud9g0d=~W%NM099sEMT=w?U$;G=V|8(z=EBnmd-fgRL`0jj?02WwXLm}ccjpt z(9RYo4BK(cVqSNMD;l*r_^ozsX-y&_D`wnwNuFXLBVIB_<2uVGd{bm>#1)z%;*+>? zLI-QrW{!0m)tIWNjcUyv464m)jjHUyRb?a0c{Rs1DX5YYlMOnZ_gX{bGzJWbS&A*e zkRid=SAktu0oxFPZRAL$u*vgrz3iu65cey(M{t8oc4I;l^~JE+hbHL=5i-$QY*7`a zofUE?1=Heo(IiFOO)VcF&|qaT`^7_km|ByOED~} zRyZ=E>W+@6$}9&>Ar~oN)X~XE9{EXOJuo7&W)<1eZ56iLA{>jOkU~+;g-~9Iqcgys z1l0>Cft!{qnyG6Ld)SGN`>lq*Y+HS2h-T= zIFi9jLPMKhs*Lla<)r} zSUCIE%eH2S!#$S_TNID0%2#7Oaj> zuX{X=CxR{U1x-)-cq)=mv>^5e`{Fa|4y5sPFdlE6|IE-oN=28$Hluj9ELADSl{$LH z_KFKWy|oAkCzjY zT{Lc%%@|(s@oHqW?7S|zP#_$y$#uD^Yk3ZbWw%0~C&=M-V{>q*etOGz8gE2lJHtxC zqdwjYey(qAe~Zd>6NbnguIEf@Wf|U1<4B-1w3OZam8 z|I3Nc@He<5iLavAudC92!|`nj-+9KEkM9#krB}Uq{9uUo7@{>lO5D?(>COB2i5SN! zs+>0Fd0y8BHD3SA-@Epm5Lb4hA^4+l){Nd*vB<*V#cXHi8G8Sx{(tWhl VI4|BXaXuJtN+~ms{@1j)_J5|9c3A)b diff --git a/dptb/tests/test_block_to_feature.py b/dptb/tests/test_block_to_feature.py new file mode 100644 index 00000000..c392ecf7 --- /dev/null +++ b/dptb/tests/test_block_to_feature.py @@ -0,0 +1,165 @@ +import pytest +import os +import torch +from dptb.nn.nnsk import NNSK +from dptb.data.transforms import OrbitalMapper +from dptb.data.build import build_dataset +from pathlib import Path +from dptb.data import AtomicDataset, DataLoader, AtomicDataDict, AtomicData +import numpy as np +from dptb.nn.hamiltonian import SKHamiltonian +from dptb.data.interfaces.ham_to_feature import block_to_feature, feature_to_block +from dptb.utils.constants import anglrMId +from e3nn.o3 import wigner_3j, Irrep, xyz_to_angles, Irrep + +rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data") + +class TestBlock2Feature: + common_options = { + "basis": { + "B": ["2s", "2p"], + "N": ["2s", "2p"] + }, + "device": "cpu", + "dtype": "float32", + "overlap": False, + "seed": 3982377700 + } + model_options = { + "nnsk": { + "onsite": { + "method": "none" + }, + "hopping": { + "method": "powerlaw", + "rs": 2.6, + "w": 0.35 + }, + "freeze": False, + "std": 0.1, + "push": None} + } + data_options = { + "train": { + "root": f"{rootdir}/hBN/dataset", + "prefix": "kpath", + "get_eigenvalues": True + } + } + + train_datasets = build_dataset(**data_options["train"], **common_options) + train_loader = DataLoader(dataset=train_datasets, batch_size=1, shuffle=True) + + batch = next(iter(train_loader)) + batch = AtomicData.to_AtomicDataDict(batch) + idp_sk = OrbitalMapper(basis=common_options['basis'], method="sktb") + idp = OrbitalMapper(basis=common_options['basis'], method="e3tb") + + sk2irs = { + 's-s': torch.tensor([[1.]]), + 's-p': torch.tensor([[1.]]), + 's-d': torch.tensor([[1.]]), + 'p-s': torch.tensor([[1.]]), + 'p-p': torch.tensor([ + [3**0.5/3,2/3*3**0.5],[6**0.5/3,-6**0.5/3] + ]), + 'p-d':torch.tensor([ + [(2/5)**0.5,(6/5)**0.5],[(3/5)**0.5,-2/5**0.5] + ]), + 'd-s':torch.tensor([[1.]]), + 'd-p':torch.tensor([ + [(2/5)**0.5,(6/5)**0.5], + [(3/5)**0.5,-2/5**0.5] + ]), + 'd-d':torch.tensor([ + [5**0.5/5, 2*5**0.5/5, 2*5**0.5/5], + [2*(1/14)**0.5,2*(1/14)**0.5,-4*(1/14)**0.5], + [3*(2/35)**0.5,-4*(2/35)**0.5,(2/35)**0.5] + ]) + } + + def test_transform_onsiteblocks_none(self): + hamiltonian = SKHamiltonian(idp_sk=self.idp_sk, onsite=True) + nnsk = NNSK(**self.common_options, **self.model_options["nnsk"],transform=False) + data = nnsk(self.batch) + data = hamiltonian(data) + + block = feature_to_block(data, nnsk.idp) + block_to_feature(data, nnsk.idp, blocks=block) + assert data[AtomicDataDict.NODE_FEATURES_KEY].shape == torch.Size([2, 13]) + + expected_onsite = torch.tensor([[-18.4200038910, 0.0000000000, 0.0000000000, 0.0000000000, + -7.2373123169, -0.0000000000, -0.0000000000, -0.0000000000, + -7.2373123169, -0.0000000000, -0.0000000000, -0.0000000000, + -7.2373123169], + [ -9.3830089569, 0.0000000000, 0.0000000000, 0.0000000000, + -3.7138016224, -0.0000000000, -0.0000000000, -0.0000000000, + -3.7138016224, -0.0000000000, -0.0000000000, -0.0000000000, + -3.7138016224]]) + assert torch.allclose(data[AtomicDataDict.NODE_FEATURES_KEY], expected_onsite) + + def test_transform_hoppingblocks(self): + hamiltonian = SKHamiltonian(idp_sk=self.idp_sk, onsite=True) + nnsk = NNSK(**self.common_options, **self.model_options["nnsk"],transform=False) + nnsk.hopping_param.data = torch.tensor([[[-0.0299384445, -0.0187778082], + [ 0.1915897578, 0.0690195262], + [-0.2321701497, -0.1196410209], + [ 0.0197028164, -0.1177332327]], + + [[ 0.0550494418, -0.0191540867], + [-0.1395172030, 0.0475118719], + [-0.0351739973, 0.0052711815], + [ 0.0192712545, -0.1666133553]], + + [[ 0.0550494418, -0.0191540867], + [ 0.0586687513, 0.0158295482], + [-0.0351739973, 0.0052711815], + [ 0.0192712545, -0.1666133553]], + + [[ 0.1311892122, -0.0209838580], + [ 0.0781731308, 0.0989692509], + [ 0.0414713360, -0.1508950591], + [ 0.2036036998, 0.0131590459]]]) + data = nnsk(self.batch) + data = hamiltonian(data) + + block = feature_to_block(data, nnsk.idp) + block_to_feature(data, nnsk.idp, blocks=block) + + assert data[AtomicDataDict.EDGE_FEATURES_KEY].shape == torch.Size([18, 13]) + + + expected_selected_hopblock = torch.tensor([[ 5.3185172379e-02, -4.6635824091e-09, 1.3500485174e-09, + 3.0885510147e-02, 8.2756355405e-02, 4.3990724937e-16, + 1.0063905265e-08, 4.3990724937e-16, 8.2756355405e-02, + -2.9133742085e-09, 1.0063905265e-08, -2.9133742085e-09, + 1.6106124967e-02], + [ 6.2371429056e-02, -6.6437192261e-02, 2.9040618799e-09, + 2.9040618799e-09, -3.9765007794e-02, 2.7151161319e-09, + 2.7151161319e-09, 2.7151161319e-09, 2.2349609062e-02, + -1.1868149201e-16, 2.7151161319e-09, -1.1868149201e-16, + 2.2349609062e-02], + [ 5.3185172379e-02, -0.0000000000e+00, 1.3500485174e-09, + -3.0885510147e-02, 8.2756355405e-02, 0.0000000000e+00, + 0.0000000000e+00, 0.0000000000e+00, 8.2756355405e-02, + 2.9133742085e-09, 0.0000000000e+00, 2.9133742085e-09, + 1.6106124967e-02], + [ 6.2371429056e-02, 3.3218599856e-02, 2.9040618799e-09, + -5.7536296546e-02, 6.8209525198e-03, -1.3575581770e-09, + 2.6896420866e-02, -1.3575582880e-09, 2.2349609062e-02, + 2.3513595515e-09, 2.6896420866e-02, 2.3513595515e-09, + -2.4236353114e-02], + [ 6.2371429056e-02, -1.5878447890e-01, -6.9406898007e-09, + 1.1987895121e-08, -3.9765007794e-02, -2.7151161319e-09, + 4.6895234362e-09, -2.7151161319e-09, 2.2349609062e-02, + 2.0498557313e-16, 4.6895234362e-09, 2.0498557313e-16, + 2.2349609062e-02], + [-1.0692023672e-02, 5.7914875448e-02, 2.9231701504e-09, + 3.3437173814e-02, -5.7711567730e-02, -3.2524076765e-09, + -3.7203211337e-02, -3.2524074545e-09, 6.7262742668e-03, + -1.8777785993e-09, -3.7203207612e-02, -1.8777785993e-09, + -1.4753011055e-02]]) + + assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][[0,3,9,5,12,15]] - expected_selected_hopblock) < 1e-6) + +