Skip to content

Commit

Permalink
add tests (#89)
Browse files Browse the repository at this point in the history
* test: create test_typemapper.py

* update transforms.py and test_typemapper.py

* test: update test_typemapper.py

* update transforms.py and create test_bondmapper

* update transforms.py

* update transforms.py

* update constants.py

* add test_orbital_mapper_init_str_spdf

* add assert for init sktb using string

* update transforms.py

* add test_orbitalmapper.py

* test: update test_orbitalmapper.py
  • Loading branch information
QG-phy authored Mar 21, 2024
1 parent 330d69a commit fcf25f8
Show file tree
Hide file tree
Showing 5 changed files with 889 additions and 26 deletions.
101 changes: 79 additions & 22 deletions dptb/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,15 @@ def transform(self, atomic_numbers):
f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
)

return self._Z_to_index.to(device=atomic_numbers.device)[
atomic_numbers - self._min_Z
]
types = self._Z_to_index.to(device=atomic_numbers.device)[atomic_numbers - self._min_Z]

if -1 in types:
bad_set = set(torch.unique(atomic_numbers).cpu().tolist()) - self._valid_set
raise ValueError(
f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
)

return types

def untransform(self, atom_types):
"""Transform atom types back into atomic numbers"""
Expand Down Expand Up @@ -288,15 +294,28 @@ def transform_bond(self, iatomic_numbers, jatomic_numbers):
f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
)

return self._ZZ_to_index.to(device=iatomic_numbers.device)[
iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z
]

bondtypes = self._ZZ_to_index.to(device=iatomic_numbers.device)[iatomic_numbers - self._min_Z,
jatomic_numbers - self._min_Z]

if -1 in bondtypes:
bad_set1 = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set
bad_set2 = set(torch.unique(jatomic_numbers).cpu().tolist()) - self._valid_set
bad_set = bad_set1.union(bad_set2)
raise ValueError(
f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
)

return bondtypes

def transform_reduced_bond(self, iatomic_numbers, jatomic_numbers):

if iatomic_numbers.device != jatomic_numbers.device:
raise ValueError("iatomic_numbers and jatomic_numbers should be on the same device!")

if not torch.all((iatomic_numbers -jatomic_numbers)<=0):
raise ValueError("iatomic_numbers[i] should <= jatomic_numbers[i]")

if iatomic_numbers.min() < self._min_Z or iatomic_numbers.max() > self._max_Z:
bad_set = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set
raise ValueError(
Expand All @@ -309,9 +328,19 @@ def transform_reduced_bond(self, iatomic_numbers, jatomic_numbers):
f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
)

return self._ZZ_to_reduced_index.to(device=iatomic_numbers.device)[
iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z
]

red_bondtypes = self._ZZ_to_reduced_index.to(device=iatomic_numbers.device)[
iatomic_numbers - self._min_Z, jatomic_numbers - self._min_Z]

if -1 in red_bondtypes:
bad_set1 = set(torch.unique(iatomic_numbers).cpu().tolist()) - self._valid_set
bad_set2 = set(torch.unique(jatomic_numbers).cpu().tolist()) - self._valid_set
bad_set = bad_set1.union(bad_set2)
raise ValueError(
f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
)

return red_bondtypes

def untransform_atom(self, atom_types):
"""Transform atom types back into atomic numbers"""
Expand Down Expand Up @@ -404,6 +433,8 @@ def __init__(
when list, "2s" indicates an "s" orbital in the second shell.
when str, "2s" indicates two s orbitals,
"2s2p3d4f" is equivalent to ["1s", "2s", "1p", "2p", "1d", "2d", "3d", "1f", "2f", "3f", "4f"]
Note: the list basis can be used for both e3tb and sktb, but the string basis can only be used for e3tb.
"""

#TODO: use OrderedDict to fix the order of the dict used as index map
Expand All @@ -421,23 +452,36 @@ def __init__(
raise ValueError

if isinstance(self.basis[self.type_names[0]], str):
orbtype_count = {"s":0, "p":0, "d":0, "f":0}
assert method == "e3tb", "The method should be e3tb when the basis is given as string."
all_orb_types = []
for iatom, ibasis in self.basis.items():
letters = [letter for letter in ibasis if letter.isalpha()]
all_orb_types = all_orb_types + letters
if len(letters) != len(set(letters)):
raise ValueError(f"Duplicate orbitals found in the basis {ibasis} of atom {iatom}")
all_orb_types = set(all_orb_types)
orbtype_count = {"s":0, "p":0, "d":0, "f":0, "g":0, "h":0}

if not all_orb_types.issubset(set(orbtype_count.keys())):
raise ValueError(f"Invalid orbital types {all_orb_types} found in the basis. now only support {set(orbtype_count.keys())}.")

orbs = map(lambda bs: re.findall(r'[1-9]+[A-Za-z]', bs), self.basis.values())
for ib in orbs:
for io in ib:
assert len(io) == 2
if int(io[0]) > orbtype_count[io[1]]:
orbtype_count[io[1]] = int(io[0])
# split into list basis
basis = {k:[] for k in self.type_names}
for ib in self.basis.keys():
for io in ["s", "p", "d", "f"]:
for io in ["s", "p", "d", "f", "g", "h"]:
if io in self.basis[ib]:
basis[ib].extend([str(i)+io for i in range(1, int(re.findall(r'[1-9]+'+io, self.basis[ib])[0][0])+1)])
self.basis = basis

elif isinstance(self.basis[self.type_names[0]], list):
nb = len(self.type_names)
orbtype_count = {"s":[0]*nb, "p":[0]*nb, "d":[0]*nb, "f":[0]*nb}
orbtype_count = {"s":[0]*nb, "p":[0]*nb, "d":[0]*nb, "f":[0]*nb, "g":[0]*nb, "h":[0]*nb}
for ib, bt in enumerate(self.type_names):
for io in self.basis[bt]:
orb = re.findall(r'[A-Za-z]', io)[0]
Expand All @@ -447,12 +491,23 @@ def __init__(
orbtype_count[ko] = max(orbtype_count[ko])

self.orbtype_count = orbtype_count
self.full_basis_norb = 1 * orbtype_count["s"] + 3 * orbtype_count["p"] + 5 * orbtype_count["d"] + 7 * orbtype_count["f"]

full_basis_norb = 0
for ko in orbtype_count.keys():
assert ko in anglrMId
full_basis_norb = full_basis_norb + (2 * anglrMId[ko] + 1) * orbtype_count[ko]
# self.full_basis_norb = 1 * orbtype_count["s"] + 3 * orbtype_count["p"] + 5 * orbtype_count["d"] + 7 * orbtype_count["f"]
self.full_basis_norb = full_basis_norb

if self.method == "e3tb":
self.reduced_matrix_element = int(((orbtype_count["s"] + 9 * orbtype_count["p"] + 25 * orbtype_count["d"] + 49 * orbtype_count["f"]) + \
self.full_basis_norb ** 2)/2) # reduce onsite elements by blocks. we cannot reduce it by element since the rme will pass into CG basis to form the whole block
# The full basis holds self.full_basis_norb ** 2 matrix elements in total.
# The onsite blocks cannot be reduced; they contribute orbtype_count["s"] + 9 * orbtype_count["p"] + 25 * orbtype_count["d"] + 49 * orbtype_count["f"] elements (plus the corresponding g and h terms).
# The reduced count is therefore the sum of the full and onsite-block element counts, divided by 2.
total_onsite_block_elements = 0
for ko in orbtype_count.keys():
total_onsite_block_elements += orbtype_count[ko] * (2 * anglrMId[ko] + 1)**2
self.reduced_matrix_element = int((self.full_basis_norb ** 2 + total_onsite_block_elements)/2)
#self.reduced_matrix_element = int(((orbtype_count["s"] + 9 * orbtype_count["p"] + 25 * orbtype_count["d"] + 49 * orbtype_count["f"]) + \
# self.full_basis_norb ** 2)/2) # reduce onsite elements by blocks. we cannot reduce it by element since the rme will pass into CG basis to form the whole block
else:
# two factor: this outside one is the number of min(l,l')+1, ie. the number of sk integrals for each orbital pair.
# the inside one the type of bond considering the interaction between different orbitals. s-p -> p-s. there are 2 types of bond. and 1 type of s-s.
Expand All @@ -473,6 +528,8 @@ def __init__(
) + \
4 * (orbtype_count["f"] * orbtype_count["f"])

assert orbtype_count['g'] + orbtype_count['h'] == 0, "g and h orbitals are not supported in sktb method."

self.reduced_matrix_element = self.reduced_matrix_element + orbtype_count["s"] + 2*orbtype_count["p"] + 3*orbtype_count["d"] + 4*orbtype_count["f"]
self.reduced_matrix_element = int(self.reduced_matrix_element / 2)
self.n_onsite_Es = orbtype_count["s"] + orbtype_count["p"] + orbtype_count["d"] + orbtype_count["f"]
Expand All @@ -486,15 +543,15 @@ def __init__(

# TODO: get full basis set
full_basis = []
for io in ["s", "p", "d", "f"]:
for io in ["s", "p", "d", "f", "g", "h"]:
full_basis = full_basis + [str(i)+io for i in range(1, orbtype_count[io]+1)]
self.full_basis = full_basis

# TODO: get the mapping from list basis to full basis
self.basis_to_full_basis = {}
self.atom_norb = torch.zeros(len(self.type_names), dtype=torch.long, device=self.device)
for ib in self.basis.keys():
count_dict = {"s":0, "p":0, "d":0, "f":0}
count_dict = {"s":0, "p":0, "d":0, "f":0, "g":0, "h":0}
self.basis_to_full_basis.setdefault(ib, {})
for o in self.basis[ib]:
io = re.findall(r"[a-z]", o)[0]
Expand Down Expand Up @@ -527,7 +584,7 @@ def __init__(
assert (self.mask_to_basis.sum(dim=1).int()-self.atom_norb).abs().sum() <= 1e-6

self.get_orbpair_maps()
# the mask to map the full basis reduced matrix element to the original basis reduced matrix element
# the mask to map the full basis edge/node reduced matrix element (erme/nrme) to the original basis reduced matrix element
self.mask_to_erme = torch.zeros(len(self.bond_types), self.reduced_matrix_element, dtype=torch.bool, device=self.device)
self.mask_to_nrme = torch.zeros(len(self.type_names), self.reduced_matrix_element, dtype=torch.bool, device=self.device)
for ib, bb in self.basis.items():
Expand Down Expand Up @@ -574,9 +631,9 @@ def get_orbpairtype_maps(self):

self.orbpairtype_maps = {}
ist = 0
for i, io in enumerate(["s", "p", "d", "f"]):
for i, io in enumerate(["s", "p", "d", "f", "g", "h"]):
if self.orbtype_count[io] != 0:
for jo in ["s", "p", "d", "f"][i:]:
for jo in ["s", "p", "d", "f", "g", "h"][i:]:
if self.orbtype_count[jo] != 0:
orb_pair = io+"-"+jo
il, jl = anglrMId[io], anglrMId[jo]
Expand Down Expand Up @@ -650,7 +707,7 @@ def get_skonsitetype_maps(self):
ist = 0

assert self.method == "sktb", "Only sktb orbitalmapper have skonsite maps."
for i, io in enumerate(["s", "p", "d", "f"]):
for i, io in enumerate(["s", "p", "d", "f", "g", "h"]):
if self.orbtype_count[io] != 0:
il = anglrMId[io]
numonsites = self.orbtype_count[io]
Expand Down
Loading

0 comments on commit fcf25f8

Please sign in to comment.