diff --git a/dptb/nn/nnsk.py b/dptb/nn/nnsk.py index 0117bf63..33900b3e 100644 --- a/dptb/nn/nnsk.py +++ b/dptb/nn/nnsk.py @@ -134,6 +134,7 @@ def __init__( strain=hasattr(self, "strain_param"),soc=hasattr(self, "soc_param")) if overlap: self.overlap = SKHamiltonian(idp_sk=self.idp_sk, onsite=False, edge_field=AtomicDataDict.EDGE_OVERLAP_KEY, node_field=AtomicDataDict.NODE_OVERLAP_KEY, dtype=self.dtype, device=self.device) + self.register_buffer("ovp_factor", torch.tensor(1.0, dtype=self.dtype, device=self.device)) self.idp = self.hamiltonian.idp if freeze: @@ -188,7 +189,7 @@ def freezefunc(self, freeze: Union[bool,str,list]): raise ValueError("freeze is True, all parameters should frozen. But the frozen_params != all model.named_parameters. Please check the freeze tag.") log.info(f'The {frozen_params} are frozen!') - def push_decay(self, rs_thr: float=0., rc_thr: float=0., w_thr: float=0., period:int=100): + def push_decay(self, rs_thr: float=0., rc_thr: float=0., w_thr: float=0., ovp_thr: float=0., period:int=100): """Push the soft cutoff function Parameters @@ -207,6 +208,8 @@ def push_decay(self, rs_thr: float=0., rc_thr: float=0., w_thr: float=0., period self.hopping_options["w"] += w_thr if abs(rc_thr) > 0: self.hopping_options["rc"] += rc_thr + if abs(ovp_thr) > 0 and self.ovp_factor >=ovp_thr: + self.ovp_factor -= ovp_thr self.model_options["nnsk"]["hopping"] = self.hopping_options @@ -278,7 +281,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # the overlap tag now is only designed to be used in the NRL-TB case. In the future, we may need to change this. paraconst = edge_number[0].eq(edge_number[1]).float().view(-1, 1) * equal_orbpair.unsqueeze(0) - data[AtomicDataDict.EDGE_OVERLAP_KEY] = self.overlap_fn.get_sksij( + data[AtomicDataDict.EDGE_OVERLAP_KEY] = self.ovp_factor * self.overlap_fn.get_sksij( rij=data[AtomicDataDict.EDGE_LENGTH_KEY], paraArray=self.overlap_param[edge_index], paraconst=paraconst, @@ -795,7 +798,7 @@ def to_json(self,version=2): "overlap": is_overlap, } - if version ==2: + if version == 2: ckpt.update({"model_options": self.model_options, "common_options": common_options}) diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py index 1589351a..31f1dd6e 100644 --- a/dptb/utils/argcheck.py +++ b/dptb/utils/argcheck.py @@ -597,12 +597,14 @@ def push(): doc_rs_thr = "The step size for cutoff value for smooth function in the nnsk anlytical formula." doc_rc_thr = "The step size for cutoff value for smooth function in the nnsk anlytical formula." doc_w_thr = "The step size for decay factor w." + doc_ovp_thr = "The step size for overlap reduction" doc_period = "the interval of iterations to modify the rs w values." return Argument("push", [bool,dict], sub_fields=[ Argument("rs_thr", [int,float], optional=True, default=0., doc=doc_rs_thr), Argument("rc_thr", [int,float], optional=True, default=0., doc=doc_rc_thr), Argument("w_thr", [int,float], optional=True, default=0., doc=doc_w_thr), + Argument("ovp_thr", [int,float], optional=True, default=0., doc=doc_ovp_thr), Argument("period", int, optional=True, default=100, doc=doc_period), ], sub_variants=[], optional=True, default=False, doc="The parameters to define the push the soft cutoff of nnsk model.") diff --git a/dptb/utils/config_sk.py b/dptb/utils/config_sk.py index f82de028..bdf6728c 100644 --- a/dptb/utils/config_sk.py +++ b/dptb/utils/config_sk.py @@ -56,7 +56,7 @@ }, "freeze": False, "std": 0.01, - "push": False or {"w_thr": 0.0,"period": 1,"rs_thr": 0.0,"rc_thr": 0.0}, + "push": False or {"w_thr": 0.0,"period": 1,"rs_thr": 0.0, "ovp_thr": 0.0, "rc_thr": 0.0}, } }, "data_options": {