Skip to content

Commit

Permalink
update e3deeph module
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingCatty committed Nov 29, 2023
1 parent 8e3950c commit ebf326f
Show file tree
Hide file tree
Showing 20 changed files with 1,475 additions and 82 deletions.
8 changes: 4 additions & 4 deletions dptb/data/dataset/_abacus_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def get(self, idx):
for key, value in data["basis"].items():
basis[key] = [(f"{i+1}" + orbitalLId[l]) for i, l in enumerate(value)]
idp = OrbitalMapper(basis)
# e3 = E3Hamiltonian(idp=idp, decompose=True)
e3 = E3Hamiltonian(idp=idp, decompose=True)
ham_block_to_feature(atomic_data, idp, data.get("hamiltonian_blocks", False), data.get("overlap_blocks", False))
# with torch.no_grad():
# atomic_data = e3(atomic_data.to_dict())
# atomic_data = AtomicData.from_dict(atomic_data)
with torch.no_grad():
atomic_data = e3(atomic_data.to_dict())
atomic_data = AtomicData.from_dict(atomic_data)
if data.get("eigenvalue") and data.get("kpoint"):
atomic_data[AtomicDataDict.KPOINT_KEY] = torch.as_tensor(data["kpoint"][:], dtype=torch.get_default_dtype())
atomic_data[AtomicDataDict.ENERGY_EIGENVALUE_KEY] = torch.as_tensor(data["eigenvalue"][:], dtype=torch.get_default_dtype())
Expand Down
15 changes: 11 additions & 4 deletions dptb/data/interfaces/abacus.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@ def __init__(self):
self.Us_abacus2deeptb[0] = np.eye(1)
self.Us_abacus2deeptb[1] = np.eye(3)[[2, 0, 1]] # 0, 1, -1 -> -1, 0, 1
self.Us_abacus2deeptb[2] = np.eye(5)[[4, 2, 0, 1, 3]] # 0, 1, -1, 2, -2 -> -2, -1, 0, 1, 2
self.Us_abacus2deeptb[3] = np.eye(7)[[6, 4, 2, 0, 1, 3, 5]]
self.Us_abacus2deeptb[3] = np.eye(7)[[6, 4, 2, 0, 1, 3, 5]] # -3,-2,-1,0,1,2,3

# minus_dict = {
# 1: [1, 2],
# 2: [0, 2],
# 3: [0, 2, 4, 6],
# }
# for k, v in minus_dict.items():
# self.Us_abacus2deeptb[k][v] *= -1 # add phase (-1)^m

minus_dict = {
1: [0, 2],
2: [1, 3],
3: [0, 2, 4, 6],
}

for k, v in minus_dict.items():
self.Us_abacus2deeptb[k][v] *= -1 # add phase (-1)^m

def get_U(self, l):
if l > 3:
Expand Down Expand Up @@ -236,7 +243,7 @@ def parse_matrix(matrix_path, factor, spinful=False):
site_norbits_cumsum[index_site_i] * (1 + spinful),
(site_norbits_cumsum[index_site_j] - site_norbits[index_site_j]) * (1 + spinful):
site_norbits_cumsum[index_site_j] * (1 + spinful)]
if abs(mat).max() < 1e-8:
if abs(mat).max() < 1e-10:
continue
if not spinful:
mat = U_orbital.transform(mat, orbital_types_dict[element[index_site_i]],
Expand Down
54 changes: 51 additions & 3 deletions dptb/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

import ase.data
import e3nn.o3 as o3

from dptb.data import AtomicData, AtomicDataDict

Expand Down Expand Up @@ -455,8 +456,7 @@ def __init__(

# Get the mask for mapping from full basis to atom specific basis
self.mask_to_basis = torch.zeros(len(self.type_names), self.full_basis_norb, dtype=torch.bool)
self.mask_to_erme = torch.zeros(len(self.type_names), self.edge_reduced_matrix_element, dtype=torch.bool)
self.mask_to_nrme = torch.zeros(len(self.type_names), self.node_reduced_matrix_element, dtype=torch.bool)

for ib in self.basis.keys():
ibasis = list(self.basis_to_full_basis[ib].values())
ist = 0
Expand All @@ -469,9 +469,21 @@ def __init__(

assert (self.mask_to_basis.sum(dim=1).int()-self.atom_norb).abs().sum() <= 1e-6


self.get_pair_maps()
self.get_node_maps()

self.mask_to_erme = torch.zeros(len(self.reduced_bond_types), self.edge_reduced_matrix_element, dtype=torch.bool)
self.mask_to_nrme = torch.zeros(len(self.type_names), self.node_reduced_matrix_element, dtype=torch.bool)
for ib in self.basis.keys():
for opair in self.node_maps:
self.mask_to_nrme[self.chemical_symbol_to_type[ib]][self.node_maps[opair]] = True


for ib in self.reduced_bond_to_type.keys():
for opair in self.pair_maps:
self.mask_to_erme[self.reduced_bond_to_type[ib]][self.pair_maps[opair]] = True


def get_pairtype_maps(self):
"""
The function `get_pairtype_maps` creates a mapping of orbital pair types, such as s-s, "s-p",
Expand Down Expand Up @@ -619,3 +631,39 @@ def get_orbital_maps(self):

return self.orbital_maps

def get_irreps(self, no_parity=True):
    """Build the e3nn irreps describing the node and pair (edge) reduced matrix elements.

    Parameters
    ----------
    no_parity : bool, default True
        If True all irreps are built with even parity (factor = +1);
        otherwise each angular momentum l gets parity (-1)**l.

    Returns
    -------
    (o3.Irreps, o3.Irreps)
        ``(node_irreps, pair_irreps)``.  The result is cached on ``self``
        after the first call; later calls return the cached pair.
    """
    assert self.method == "e3tb", "Only support e3tb method for now."

    if hasattr(self, "node_irreps") and hasattr(self, "pair_irreps"):
        # Bug fix: the cached path previously returned self.node_maps
        # (a dict of slices) instead of the cached node irreps.
        return self.node_irreps, self.pair_irreps

    # Lazily build the orbital-pair slice maps the decomposition needs.
    if not hasattr(self, "nodetype_maps"):
        self.get_nodetype_maps()
    if not hasattr(self, "pairtype_maps"):
        self.get_pairtype_maps()

    # factor = +1 ignores parity; -1 assigns parity (-1)**l per irrep.
    factor = 1 if no_parity else -1

    def _decompose(type_maps):
        """Decompose every (l1, l2) orbital-pair channel into irreps."""
        irs = []
        for pair, sli in type_maps.items():
            l1, l2 = anglrMId[pair[0]], anglrMId[pair[2]]
            ir1 = o3.Irrep((l1, factor ** l1))
            ir2 = o3.Irrep((l2, factor ** l2))
            # A slice may cover several identical (l1, l2) channels;
            # repeat the tensor-product irreps for each of them.
            mult = int((sli.stop - sli.start) / (2 * l1 + 1) / (2 * l2 + 1))
            irs += [ir for ir in ir1 * ir2] * mult
        return o3.Irreps(irs)

    self.pair_irreps = _decompose(self.pairtype_maps)
    self.node_irreps = _decompose(self.nodetype_maps)
    return self.node_irreps, self.pair_irreps


16 changes: 12 additions & 4 deletions dptb/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
in_field: AtomicDataDict.NODE_FEATURES_KEY,
out_field: AtomicDataDict.NODE_FEATURES_KEY,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
if_batch_normalized: bool = False,
if_batch_normalized: bool = False,
device: Union[str, torch.device] = torch.device('cpu'),
dtype: Union[str, torch.dtype] = torch.float32,
**kwargs
Expand Down Expand Up @@ -140,13 +140,18 @@ def __init__(
self.out_layer = AtomicMLP(**config[-1], in_field=out_field, out_field=out_field, if_batch_normalized=False, activation=activation, device=device, dtype=dtype)
self.out_field = out_field
self.in_field = in_field
# self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True)

def forward(self, data: AtomicDataDict.Type):
    """Run the MLP stack: hidden layers with activation, then the output layer.

    Parameters
    ----------
    data : AtomicDataDict.Type
        Atomic data dict; features are read from ``self.in_field`` and
        written to ``self.out_field``.

    Returns
    -------
    AtomicDataDict.Type
        The same dict with ``self.out_field`` replaced by the network output.
    """
    # NOTE(review): out_scale/out_shift are computed but never applied
    # below — presumably intended for output rescaling; confirm intent.
    out_scale = self.out_scale(data[self.in_field])
    out_shift = self.out_shift(data[self.in_field])
    for layer in self.layers:
        data = layer(data)
        data[self.out_field] = self.activation(data[self.out_field])

    # Bug fix: a stray early `return self.out_layer(data)` made the
    # lines below unreachable; keep only the intended final path.
    data = self.out_layer(data)
    # data[self.out_field] = self.out_norm(data[self.out_field])
    return data


class AtomicResBlock(torch.nn.Module):
Expand Down Expand Up @@ -264,13 +269,16 @@ def __init__(
self.out_layer = AtomicLinear(in_features=config[-1]['in_features'], out_features=config[-1]['out_features'], in_field=out_field, out_field=out_field, device=device, dtype=dtype)
else:
self.out_layer = AtomicMLP(**config[-1], if_batch_normalized=False, in_field=in_field, out_field=out_field, activation=activation, device=device, dtype=dtype)
# self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True)

def forward(self, data: AtomicDataDict.Type):
    """Apply each hidden layer with activation, then the output layer.

    Parameters
    ----------
    data : AtomicDataDict.Type
        Atomic data dict; ``self.out_field`` is updated in place by each
        layer and by the activation.

    Returns
    -------
    AtomicDataDict.Type
        The dict with ``self.out_field`` holding the final output.
    """
    for layer in self.layers:
        data = layer(data)
        data[self.out_field] = self.activation(data[self.out_field])

    # Bug fix: removed the duplicated early `return self.out_layer(data)`
    # that made the statements below unreachable.
    data = self.out_layer(data)
    # data[self.out_field] = self.out_norm(data[self.out_field])
    return data

class MLP(nn.Module):
def __init__(
Expand Down
16 changes: 11 additions & 5 deletions dptb/nn/deeptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
self.method = prediction["hamiltonian"].get("method", "e3tb")
self.overlap = prediction["hamiltonian"].get("overlap", False)
self.soc = prediction["hamiltonian"].get("soc", False)
self.prediction = prediction

if basis is not None:
self.idp = OrbitalMapper(basis, method=self.method)
Expand All @@ -87,10 +88,10 @@ def __init__(


# initialize the embedding layer
self.embedding = Embedding(**embedding, dtype=dtype, device=device, n_atom=len(self.basis.keys()))
self.embedding = Embedding(**embedding, dtype=dtype, device=device, idp=self.idp, n_atom=len(self.basis.keys()))

# initialize the prediction layer
if prediction["method"] == "linear":
if prediction.get("method") == "linear":

self.node_prediction_h = AtomicLinear(
in_features=self.embedding.out_node_dim,
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
device=device
)

elif prediction["method"] == "nn":
elif prediction.get("method") == "nn":
prediction["neurons"] = [self.embedding.out_node_dim] + prediction["neurons"] + [self.idp.node_reduced_matrix_element]
prediction["config"] = get_neuron_config(prediction["neurons"])

Expand Down Expand Up @@ -158,6 +159,8 @@ def __init__(
device=device,
dtype=dtype
)
elif prediction.get("method") == "none":
pass
else:
raise NotImplementedError("The prediction model {} is not implemented.".format(prediction["method"]))

Expand Down Expand Up @@ -207,8 +210,11 @@ def forward(self, data: AtomicDataDict.Type):
data = self.embedding(data)
if self.overlap:
data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY]
data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

if not self.prediction.get("method") == "none":
data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

data = self.hamiltonian(data)

if self.overlap:
Expand Down
3 changes: 3 additions & 0 deletions dptb/nn/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from .emb import Embedding
from .se2 import SE2Descriptor
from .baseline import BASELINE
from .mpnn import MPNN
from .deephe3 import N3DeePH

__all__ = [
"Descriptor",
"SE2Descriptor",
"Identity",
"N3DeePH",
]
25 changes: 16 additions & 9 deletions dptb/nn/embedding/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def __init__(
self.rc = rc

self.p = p
self.node_emb_layer = _NODE_EMB(rc=rc, p=p, n_axis=n_axis, n_basis=n_basis, n_radial=n_radial, n_sqrt_radial=n_sqrt_radial, n_atom=n_atom, radial_net=radial_net, dtype=dtype, device=device)
self.node_emb_layer = _NODE_EMB(rc=self.rc, p=p, n_axis=n_axis, n_basis=n_basis, n_radial=n_radial, n_sqrt_radial=n_sqrt_radial, n_atom=n_atom, radial_net=radial_net, dtype=dtype, device=device)
self.layers = torch.nn.ModuleList([])
for i in range(n_layer):
self.layers.append(BaselineLayer(rc=rc, p=p, n_radial=n_radial, n_sqrt_radial=n_sqrt_radial, n_axis=n_axis, n_hidden=n_axis*n_sqrt_radial, hidden_net=hidden_net, radial_net=radial_net, dtype=dtype, device=device))
for _ in range(n_layer):
self.layers.append(BaselineLayer(n_atom=n_atom, rc=self.rc, p=p, n_radial=n_radial, n_sqrt_radial=n_sqrt_radial, n_axis=n_axis, n_hidden=n_axis*n_sqrt_radial, hidden_net=hidden_net, radial_net=radial_net, dtype=dtype, device=device))
self.onehot = OneHotAtomEncoding(num_types=n_atom, set_features=False)

def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
Expand Down Expand Up @@ -227,6 +227,7 @@ def __init__(
n_radial: int,
n_sqrt_radial: int,
n_axis: int,
n_atom: int,
n_hidden: int,
radial_net: dict={},
hidden_net: dict={},
Expand Down Expand Up @@ -261,19 +262,24 @@ def __init__(
self.dtype = dtype

def forward(self, env_length, edge_length, edge_index, env_index, env_radial, edge_radial, node_emb, env_hidden, edge_hidden):
# n_env = env_index.shape[1]
# n_edge = edge_index.shape[1]
# env_attr = atom_attr[env_index].transpose(1,0).reshape(n_env,-1)
# edge_attr = atom_attr[edge_index].transpose(1,0).reshape(n_edge,-1)

env_weight = self.mlp_emb(env_radial)
# node_emb can descripe the node very well
_node_emb = self.propagate(env_index, node_emb=node_emb[env_index[1]], env_weight=env_weight) # [N_atom, D, 3]
node_emb = 0.89442719 * node_emb + 0.4472 * self.propagate(env_index, node_emb=node_emb[env_index[1]], env_weight=env_weight) # [N_atom, D, 3]
# import matplotlib.pyplot as plt
# fig = plt.figure(figsize=(15,4))
# plt.plot(node_emb.detach().T)
# plt.title("node_emb")
# plt.show()

# env_hidden 长得太像了
env_hidden = self.mlp_hid(torch.cat([_node_emb[env_index[0]], env_hidden], dim=-1))
edge_hidden = self.mlp_hid(torch.cat([_node_emb[edge_index[0]], edge_hidden], dim=-1))
node_emb = _node_emb + node_emb
env_hidden = self.mlp_hid(torch.cat([node_emb[env_index[0]], env_hidden], dim=-1))
edge_hidden = self.mlp_hid(torch.cat([node_emb[edge_index[0]], edge_hidden], dim=-1))
# node_emb = _node_emb + node_emb

# import matplotlib.pyplot as plt
# fig = plt.figure(figsize=(15,4))
Expand All @@ -283,8 +289,8 @@ def forward(self, env_length, edge_length, edge_index, env_index, env_radial, ed

ud_env = polynomial_cutoff(x=env_length, r_max=self.rc, p=self.p).reshape(-1, 1)
ud_edge = polynomial_cutoff(x=edge_length, r_max=self.rc, p=self.p).reshape(-1, 1)
env_radial = ud_env * self.edge_layer_norm(self.mlp_radial(torch.cat([env_radial, env_hidden], dim=-1)))
edge_radial = ud_edge * self.edge_layer_norm(self.mlp_radial(torch.cat([edge_radial, edge_hidden], dim=-1)))
env_radial = 0.89442719 * env_radial + 0.4472 * ud_env * self.edge_layer_norm(self.mlp_radial(torch.cat([env_radial, env_hidden], dim=-1)))
edge_radial = 0.89442719 * edge_radial + 0.4472 * ud_edge * self.edge_layer_norm(self.mlp_radial(torch.cat([edge_radial, edge_hidden], dim=-1)))

return env_radial, env_hidden, edge_radial, edge_hidden, node_emb

Expand All @@ -304,5 +310,6 @@ def update(self, aggr_out):
_type_
_description_
"""

aggr_out = aggr_out.reshape(aggr_out.shape[0], -1)
return self.node_layer_norm(aggr_out) # [N, D*D]
Loading

0 comments on commit ebf326f

Please sign in to comment.