Skip to content

Commit

Permalink
simplify the output print
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Jul 24, 2024
1 parent ba87be5 commit ea73098
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 116 deletions.
47 changes: 11 additions & 36 deletions pyxtal/optimize/GA.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,11 @@ def run(self, ref_pmg=None, ref_eng=None, ref_pxrd=None):
ref_pxrd: reference pxrd profile in 2D array
Returns:
(generation, np.min(engs), None, None, None, 0, len(engs))
success_rate or None
"""
if ref_pmg is not None:
ref_pmg.remove_species("H")

self.best_reps = []
self.reps = []
self.engs = []
self.matches = 0

# Related to the FF optimization
N_added = 0

Expand Down Expand Up @@ -281,43 +276,23 @@ def run(self, ref_pmg=None, ref_eng=None, ref_pxrd=None):
self.N_struc = len(self.engs)

# Update the FF parameters if necessary
# import sys; sys.exit()
if self.ff_opt:
N_max = min([int(self.N_pop * 0.6), 50])
ids = np.argsort(engs)
xtals = self.select_xtals(current_xtals, ids, N_max)
print("Select Good structures for FF optimization", len(xtals))
N_added = self.ff_optimization(xtals, N_added)

else:
if self.early_quit:
match = self.early_termination(current_xtals, current_matches,
current_engs, current_tags, ref_pmg, ref_eng)

if match is not None:
print("Early termination")
self.logging.info("Early termination")
return match
else:
for m in current_matches:
if m:
self.matches += 1
success_rate = self.matches / self.N_struc * 100
gen_out = f"Success rate at Gen {gen:3d}: "
gen_out += f"{success_rate:7.4f}%"
self.logging.info(gen_out)
print(gen_out)

if success_rate > 2.5 or self.matches > 10: # to check later
msg = f"Early termination with a high success rate"
print(msg)
self.logging.info(msg)
return success_rate

if not self.ff_opt and not self.early_quit:
return success_rate
else:
return None
success_rate = self.success_count(gen, current_xtals, current_matches, current_tags, engs, ref_pmg)
gen_out = f"Success rate at Gen {gen:3d}: "
gen_out += f"{success_rate:7.4f}%"
self.logging.info(gen_out)
print(gen_out)

if self.early_termination(success_rate):
return success_rate

return None

def _selTournament(self, fitness, factor=0.35):
"""
Expand Down
43 changes: 9 additions & 34 deletions pyxtal/optimize/PSO.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,6 @@ def run(self, ref_pmg=None, ref_eng=None, ref_pxrd=None):
if ref_pmg is not None:
ref_pmg.remove_species("H")

self.best_reps = []
self.reps = []
self.engs = []
self.matches = 0

# Related to the FF optimization
N_added = 0

Expand Down Expand Up @@ -281,7 +276,6 @@ def run(self, ref_pmg=None, ref_eng=None, ref_pxrd=None):
self.N_struc = len(self.engs)

# Update the FF parameters if necessary
# import sys; sys.exit()
if self.ff_opt:
N_max = min([int(self.N_pop * 0.6), 50])
ids = np.argsort(engs)
Expand All @@ -290,35 +284,16 @@ def run(self, ref_pmg=None, ref_eng=None, ref_pxrd=None):
N_added = self.ff_optimization(xtals, N_added)

else:
match = self.early_termination(current_xtals, current_matches,
current_engs, current_tags, ref_pmg, ref_eng)
if self.early_quit:
if match is not None:
print("Early termination")
self.logging.info("Early termination")
return match
else:
#print('debug', gen, current_matches)
for m in current_matches:
if m:
self.matches += 1
success_rate = self.matches / self.N_struc * 100
gen_out = f"Success rate at Gen {gen:3d}: "
gen_out += f"{success_rate:7.4f}%"
self.logging.info(gen_out)
print(gen_out)

if success_rate > 2.5 or self.matches > 10: # to check later
msg = f"Early termination with a high success rate"
print(msg)
self.logging.info(msg)
return success_rate

if not self.ff_opt and not self.early_quit:
return success_rate
else:
return None
success_rate = self.success_count(gen, current_xtals, current_matches, current_tags, engs, ref_pmg)
gen_out = f"Success rate at Gen {gen:3d}: "
gen_out += f"{success_rate:7.4f}%"
self.logging.info(gen_out)
print(gen_out)

if self.early_termination(success_rate):
return success_rate

return None

if __name__ == "__main__":
import argparse
Expand Down
124 changes: 78 additions & 46 deletions pyxtal/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def __init__(
else:
self.random_state = np.random.default_rng(random_state)



# Molecular information
self.smile = smiles
self.smiles = self.smile.split(".") # list
Expand Down Expand Up @@ -186,9 +184,11 @@ def __init__(

# I/O stuff
self.early_quit = early_quit
self.N_min_matches = 10 # The min_num_matches for early termination
self.E_max = E_max
self.tag = tag
self.cif = cif
self.matched_cif = self.workdir + "/" + "matched.cif"
if cif is not None:
with open(self.workdir + "/" + cif, "w") as f:
f.writelines(str(self))
Expand All @@ -199,6 +199,13 @@ def __init__(
logging.basicConfig(format="%(asctime)s| %(message)s", filename=self.log_file, level=logging.INFO)
self.logging = logging

# Some neccessary trackers
self.matches = []
self.best_reps = []
self.reps = []
self.engs = []


def __str__(self):
s = "\n-------Global Crystal Structure Prediction------"
s += f"\nsmile : {self.smile:s}"
Expand Down Expand Up @@ -251,6 +258,54 @@ def select_xtals(self, ref_xtals, ids, N_max):
# xtals = [xtal.to_ase(resort=False) for xtal in xtals]
return xtals


def success_count(self, gen, xtals, matches, tags, engs, ref_pmg):
"""
To wrap up the matched results and count success rate.
Args:
gen (int): current generation index
xtals (list): list of xtals
matches (list): list of matches [True, False, ..]
tags (list): 'random' or 'mutation'
engs (list): list of engs
ref_pmg: reference pymatgen structure
Return:
success_rate
"""
for i, match in enumerate(matches):
if match:
xtal, tag = xtals[i], tags[i]
with open(self.matched_cif, "a+") as f:
res = self._print_match(xtal, ref_pmg)
e, d1, d2 = engs[i], res[0], res[1]
label = self.tag + "-g" + str(gen) + "-p" + str(i)
label += f"-e{e:8.3f}-t{tag:s}-{d1:4.2f}-{d2:4.2f}"
f.writelines(xtal.to_file(header=label))
self.matches.append((i, xtal, e, d1, d2, tag))

success_rate = len(self.matches) / self.N_struc * 100
return success_rate

def early_termination(self, success_rate):
"""
Check if the calculation can be terminated early
"""
if success_rate > 0:
if self.early_quit:
msg = f"Early termination since a match is found"
print(msg)
self.logging.info(msg)
return True

elif success_rate > 2.5 or len(self.matches) > self.N_min_matches:
msg = f"Early termination with a high success rate"
print(msg)
self.logging.info(msg)
return True
return False

def ff_optimization(self, xtals, N_added, N_min=50, dE=2.5, FMSE=2.5):
"""
Optimize the current FF based on newly explored data
Expand Down Expand Up @@ -500,50 +555,6 @@ def prepare_chm_info(self, params0, params1=None, suffix="calc/pyxtal0"):
# Info
self.atom_info = ase_with_ff.get_atom_info()

def early_termination(self, xtals, matches, engs, tags, ref_pmg, ref_eng):
"""
Exit if a match is found
"""

e, d1, d2 = 10000, 0, 0
match_id = None

# Gather all matched results
if ref_pmg is not None:
for id, match in enumerate(matches):
if match and engs[id] < e:
xtal = xtals[id]
res = self._print_match(xtal, ref_pmg)
if res[0] is not None:
e, match_id, d1, d2 = engs[id], id, res[0], res[1]

if match_id is not None:
all_engs = np.sort(np.array(self.engs))
rank = len(all_engs[all_engs < (e - 0.001)]) + 1
tag = tags[match_id][0]
done = False
if rank / self.N_struc < 0.5:
done = True
else:
if ref_eng is not None:
if e < ref_eng + 0.2:
done = True
else:
done = True
if done:
return {
"energy": e,
"tag": tag,
"l_rms": d1,
"a_rms": d2,
"rank": rank,
}
return None
return None

else:
return None

def get_label(self, i):
if i < 10:
folder = f"cpu00{i}"
Expand All @@ -553,6 +564,27 @@ def get_label(self, i):
folder = f"cpu0{i}"
return folder

def print_matches(self, header=None):
"""
Formatted output for the matched structures including xtal rep and eng ranking
"""
all_engs = np.sort(np.array(self.engs))
ranks = []
for match_data in self.matches:
(id, xtal, e, d1, d2, tag) = match_data
rep0 = xtal.get_1D_representation()
if header is not None:
strs = header
else:
strs = ""
strs = rep0.to_string(eng=xtal.energy / sum(xtal.numMols))
strs += f"{d1:6.3f}{d2:6.3f} Match "
rank = len(all_engs[all_engs < (e - 1e-3)]) + 1
strs += f"{rank:d}/{self.N_struc:d} {tag:s}"
print(strs)
ranks.append(rank)
return min(ranks)

def _print_match(self, xtal, ref_pmg):
"""
print the matched structure
Expand Down

0 comments on commit ea73098

Please sign in to comment.