Skip to content

Commit

Permalink
simplify symmetry module to reduce the memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Aug 30, 2024
1 parent 85aa6ae commit 9ec44a4
Showing 1 changed file with 52 additions and 38 deletions.
90 changes: 52 additions & 38 deletions pyxtal/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,29 @@
filtered_coords_euclidean,
)


def rf(package_name, resource_path):
package_path = importlib.util.find_spec(
package_name).submodule_search_locations[0]
return os.path.join(package_path, resource_path)

class SubgroupData:
def __init__(self):
self._t_subgroup = None
self._k_subgroup = None

# ------------------------------ Constants ---------------------------------------
def get_t_subgroup(self):
if self._t_subgroup is None:
self._t_subgroup = loadfn(rf("pyxtal", "database/t_subgroup.json"))
return self._t_subgroup

wyckoff_df = read_csv(rf("pyxtal", "database/wyckoff_list.csv"))
wyckoff_symmetry_df = read_csv(rf("pyxtal", "database/wyckoff_symmetry.csv"))
wyckoff_generators_df = read_csv(
rf("pyxtal", "database/wyckoff_generators.csv"))
layer_df = read_csv(rf("pyxtal", "database/layer.csv"))
layer_symmetry_df = read_csv(rf("pyxtal", "database/layer_symmetry.csv"))
layer_generators_df = read_csv(rf("pyxtal", "database/layer_generators.csv"))
rod_df = read_csv(rf("pyxtal", "database/rod.csv"))
rod_symmetry_df = read_csv(rf("pyxtal", "database/rod_symmetry.csv"))
rod_generators_df = read_csv(rf("pyxtal", "database/rod_generators.csv"))
point_df = read_csv(rf("pyxtal", "database/point.csv"))
point_symmetry_df = read_csv(rf("pyxtal", "database/point_symmetry.csv"))
point_generators_df = read_csv(rf("pyxtal", "database/point_generators.csv"))
symbols = loadfn(rf("pyxtal", "database/symbols.json"))
t_subgroup = loadfn(rf("pyxtal", "database/t_subgroup.json"))
k_subgroup = loadfn(rf("pyxtal", "database/k_subgroup.json"))
wyc_sets = loadfn(rf("pyxtal", "database/wyckoff_sets.json"))
def get_k_subgroup(self):
if self._k_subgroup is None:
self._k_subgroup = loadfn(rf("pyxtal", "database/k_subgroup.json"))
return self._k_subgroup
# ------------------------------ Constants ---------------------------------------
#t_subgroup = loadfn(rf("pyxtal", "database/t_subgroup.json"))
#k_subgroup = loadfn(rf("pyxtal", "database/k_subgroup.json"))
subgroup_data = SubgroupData()
hall_table = read_csv(rf("pyxtal", "database/HM_Full.csv"), sep=",")
# The map between spglib default space group and hall numbers
spglib_hall_numbers = [
Expand Down Expand Up @@ -828,7 +825,7 @@ def list_wyckoff_combinations(self, numIons, quick=False, numWp=(None, None), Nm
Args:
numIons (list): [12, 8]
quick ()Boolean): quickly generate some solutions
quick (Boolean): quickly generate some solutions
numWp (tuple): (min_wp, max_wp)
Nmax: maximumly allowed combinations
Expand Down Expand Up @@ -1044,6 +1041,7 @@ def get_alternatives(self):
Get the alternative settings as a dictionary
"""
if self.dim == 3:
wyc_sets = loadfn(rf("pyxtal", "database/wyckoff_sets.json"))
return wyc_sets[str(self.number)]
else:
msg = "Only supports the subgroups for space group"
Expand All @@ -1054,6 +1052,7 @@ def get_max_k_subgroup(self):
Returns the maximal k-subgroups as a dictionary
"""
if self.dim == 3:
k_subgroup = subgroup_data.get_k_subgroup()
return k_subgroup[str(self.number)]
else:
msg = "Only supports the subgroups for space group"
Expand All @@ -1064,6 +1063,7 @@ def get_max_t_subgroup(self):
Returns the maximal t-subgroups as a dictionary
"""
if self.dim == 3:
t_subgroup = subgroup_data.get_t_subgroup()
return t_subgroup[str(self.number)]
else:
msg = "Only supports the subgroups for space group"
Expand Down Expand Up @@ -1239,8 +1239,8 @@ def get_max_subgroup_numbers(self, max_cell=9):
"""
groups = []
if self.dim == 3:
sub_k = k_subgroup[str(self.number)]
sub_t = t_subgroup[str(self.number)]
sub_k = self.get_max_k_subgroup()
sub_t = self.get_max_t_subgroup()
k = sub_k["subgroup"]
t = sub_t["subgroup"]
for i, n in enumerate(t):
Expand Down Expand Up @@ -1541,6 +1541,9 @@ def add_k_transitions(self, path, n=1):
a list of maximal subgroup chains with extra k type transitions
"""

k_subgroup = subgroup_data.get_k_subgroup()
t_subgroup = subgroup_data.get_t_subgroup()

if n != 1:
print("only 1 extra k type supported at this time")
return None
Expand Down Expand Up @@ -1576,6 +1579,9 @@ def path_to_general_wp(self, index=1, max_steps=1):
a list of (g_types, subgroup_id, spg_number, wp_list (optional))
"""
# label = [str(self[index].multiplicity) + self[index].letter]
k_subgroup = subgroup_data.get_k_subgroup()
t_subgroup = subgroup_data.get_t_subgroup()

label = [self[index].get_label()]
potential = [[(None, None, self.number, label)]]
solutions = []
Expand Down Expand Up @@ -1750,7 +1756,6 @@ def list_groups(cls, dim=3):
group: the group symbol or international number
dim: the periodic dimension of the group
"""

import pandas as pd

keys = {
Expand All @@ -1759,7 +1764,9 @@ def list_groups(cls, dim=3):
1: "rod_group",
0: "point_group",
}
data = symbols[keys[dim]]

group_symbols = loadfn(rf("pyxtal", "database/symbols.json"))
data = group_symbols[keys[dim]]
index = range(1, len(data) + 1)
df = pd.DataFrame(index=index, data=data, columns=[keys[dim]])
pd.set_option("display.max_rows", len(df))
Expand Down Expand Up @@ -3740,13 +3747,15 @@ def get_wyckoffs(num, organized=False, dim=3):
a list of Wyckoff positions, each of which is a list of SymmOp's
"""
if dim == 3:
wyckoff_strings = eval(wyckoff_df["0"][num])
df = read_csv(rf("pyxtal", "database/wyckoff_list.csv"))
elif dim == 2:
wyckoff_strings = eval(layer_df["0"][num])
df = read_csv(rf("pyxtal", "database/layer.csv"))
elif dim == 1:
wyckoff_strings = eval(rod_df["0"][num])
df = read_csv(rf("pyxtal", "database/rod.csv"))
elif dim == 0:
wyckoff_strings = eval(point_df["0"][num])
df = read_csv(rf("pyxtal", "database/point.csv"))

wyckoff_strings = eval(df["0"][num])

wyckoffs = []
for x in wyckoff_strings:
Expand Down Expand Up @@ -3786,13 +3795,15 @@ def get_wyckoff_symmetry(num, dim=3):
point in each Wyckoff position
"""
if dim == 3:
symmetry_strings = eval(wyckoff_symmetry_df["0"][num])
symmetry_df = read_csv(rf("pyxtal", "database/wyckoff_symmetry.csv"))
elif dim == 2:
symmetry_strings = eval(layer_symmetry_df["0"][num])
symmetry_df = read_csv(rf("pyxtal", "database/layer_symmetry.csv"))
elif dim == 1:
symmetry_strings = eval(rod_symmetry_df["0"][num])
symmetry_df = read_csv(rf("pyxtal", "database/rod_symmetry.csv"))
elif dim == 0:
symmetry_strings = eval(point_symmetry_df["0"][num])
symmetry_df = read_csv(rf("pyxtal", "database/point_symmetry.csv"))

symmetry_strings = eval(symmetry_df["0"][num])

symmetry = []
# Loop over Wyckoff positions
Expand Down Expand Up @@ -3826,16 +3837,18 @@ def get_generators(num, dim=3):
a 2d list of symmop objects [[wp0], [wp1], ... ]
"""

generators = []
if dim == 3:
generator_strings = eval(wyckoff_generators_df["0"][num])
generators_df = read_csv(rf("pyxtal", "database/wyckoff_generators.csv"))
elif dim == 2:
generator_strings = eval(layer_generators_df["0"][num])
generators_df = read_csv(rf("pyxtal", "database/layer_generators.csv"))
elif dim == 1:
generator_strings = eval(rod_generators_df["0"][num])
generators_df = read_csv(rf("pyxtal", "database/rod_generators.csv"))
elif dim == 0:
generator_strings = eval(point_generators_df["0"][num])
generators_df = read_csv(rf("pyxtal", "database/point_generators.csv"))

generator_strings = eval(generators_df["0"][num])

generators = []
# Loop over Wyckoff positions
for x in generator_strings:
generators.append([])
Expand Down Expand Up @@ -4004,7 +4017,8 @@ def get_symbol_and_number(input_group, dim=3):
0: "point_group",
}

lists = symbols[keys[dim]]
group_symbols = loadfn(rf("pyxtal", "database/symbols.json"))
lists = group_symbols[keys[dim]]
number = None
symbol = None
if dim not in [0, 1, 2, 3]:
Expand Down

0 comments on commit 9ec44a4

Please sign in to comment.