Skip to content

Commit

Permalink
update deephe3 module
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingCatty committed Nov 30, 2023
1 parent ebf326f commit 2db61f2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 18 deletions.
14 changes: 7 additions & 7 deletions dptb/data/interfaces/abacus.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def __init__(self):
# 3: [0, 2, 4, 6],
# }

minus_dict = {
1: [0, 2],
2: [1, 3],
3: [0, 2, 4, 6],
}
# minus_dict = {
# 1: [0, 2],
# 2: [1, 3],
# 3: [0, 2, 4, 6],
# }

for k, v in minus_dict.items():
self.Us_abacus2deeptb[k][v] *= -1 # add phase (-1)^m
# for k, v in minus_dict.items():
# self.Us_abacus2deeptb[k][v] *= -1 # add phase (-1)^m

def get_U(self, l):
if l > 3:
Expand Down
2 changes: 1 addition & 1 deletion dptb/nn/deeptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def forward(self, data: AtomicDataDict.Type):
data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

data = self.hamiltonian(data)
# data = self.hamiltonian(data)

if self.overlap:
data = self.edge_prediction_s(data)
Expand Down
8 changes: 3 additions & 5 deletions dptb/nn/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,9 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
HR = data[self.node_field][:, self.idp.nodetype_maps[opairtype]]
HR = HR.reshape(n_node, -1, nL, nR).permute(0,2,3,1)# shape (N, nL, nR, n_pair)

# rme = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
# HR[:,:,:,None,:], dim=(1,2)) # shape (N, n_rme, n_pair)
# rme = rme.transpose(1,2).reshape(n_node, -1)

rme = HR.permute(0,3,1,2).reshape(n_node, -1)
rme = torch.sum(self.cgbasis[opairtype][None,:,:,:,None] * \
HR[:,:,:,None,:], dim=(1,2)) # shape (N, n_rme, n_pair)
rme = rme.transpose(1,2).reshape(n_node, -1)

# the onsite block doesnot have rotation
data[self.node_field][:, self.idp.nodetype_maps[opairtype]] = rme
Expand Down
14 changes: 9 additions & 5 deletions dptb/nnops/_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,15 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
# data[AtomicDataDict.NODE_FEATURES_KEY].masked_fill(~self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY]], 0.)
# data[AtomicDataDict.EDGE_FEATURES_KEY].masked_fill(~self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY]], 0.)

node_mean = ref_data[AtomicDataDict.NODE_FEATURES_KEY].mean(dim=-1, keepdim=True)
edge_mean = ref_data[AtomicDataDict.EDGE_FEATURES_KEY].mean(dim=-1, keepdim=True)
node_weight = 1/((ref_data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean).norm(dim=-1, keepdim=True)+1e-5)
edge_weight = 1/((ref_data[AtomicDataDict.EDGE_FEATURES_KEY]-edge_mean).norm(dim=-1, keepdim=True)+1e-5)

# node_mean = ref_data[AtomicDataDict.NODE_FEATURES_KEY].mean(dim=-1, keepdim=True)
# edge_mean = ref_data[AtomicDataDict.EDGE_FEATURES_KEY].mean(dim=-1, keepdim=True)
# node_weight = 1/((ref_data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean).norm(dim=-1, keepdim=True)+1e-5)
# edge_weight = 1/((ref_data[AtomicDataDict.EDGE_FEATURES_KEY]-edge_mean).norm(dim=-1, keepdim=True)+1e-5)

node_mean = 0.
edge_mean = 0.
node_weight = 1.
edge_weight = 1.

pre = (node_weight*(data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean))[self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
tgt = (node_weight*(ref_data[AtomicDataDict.NODE_FEATURES_KEY]-node_mean))[self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
Expand Down

0 comments on commit 2db61f2

Please sign in to comment.