Skip to content

Commit

Permalink
enable the run of GA and move optimize to base
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Jul 20, 2024
1 parent 22b7721 commit c5e3111
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 105 deletions.
2 changes: 1 addition & 1 deletion pyxtal/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ def rdkit_mol_init(self, smile, fix, torsions):
pruneRmsThresh=0.5,
)
N_confs = mol.GetNumConformers()
conf_id = self.random_state.choice(range(N_confs))
conf_id = int(self.random_state.choice(range(N_confs)))
conf = mol.GetConformer(conf_id)
# xyz = conf.GetPositions()
# res = AllChem.MMFFOptimizeMoleculeConfs(mol)
Expand Down
103 changes: 4 additions & 99 deletions pyxtal/optimize/GA.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from __future__ import annotations

import multiprocessing
from concurrent.futures import ProcessPoolExecutor, TimeoutError
from time import time
from typing import TYPE_CHECKING

Expand All @@ -16,7 +14,6 @@
from numpy.random import Generator

from pyxtal.optimize.base import GlobalOptimize
from pyxtal.optimize.common import optimizer_par, optimizer_single
from pyxtal.representation import representation

if TYPE_CHECKING:
Expand Down Expand Up @@ -207,101 +204,8 @@ def run(self, ref_pmg=None, ref_eng=None, ref_pxrd=None):
current_xtals[count] = self._crossover(xtal1, xtal2)
count += 1

# Local optimization (QZ: to move the block to base.py)
args = [
self.randomizer,
self.optimizer,
self.smiles,
self.block,
self.num_block,
self.atom_info,
self.workdir + "/" + "calc",
self.sg,
self.composition,
self.lattice,
self.torsions,
self.molecules,
self.sites,
ref_pmg,
self.matcher,
ref_pxrd,
self.use_hall,
self.skip_ani,
]

gen_results = [(None, None)] * len(current_xtals)
if self.ncpu == 1:
for pop in range(len(current_xtals)):
xtal = current_xtals[pop]
job_tag = self.tag + "-g" + str(gen) + "-p" + str(pop)
mutated = xtal is not None
my_args = [xtal, pop, mutated, job_tag, *args]
gen_results[pop] = optimizer_single(*tuple(my_args))

else:
# parallel process
N_cycle = int(np.ceil(self.N_pop / self.ncpu))
args_lists = []
for i in range(self.ncpu):
id1 = i * N_cycle
id2 = min([id1 + N_cycle, len(current_xtals)])
# os.makedirs(folder, exist_ok=True)
ids = range(id1, id2)
job_tags = [self.tag + "-g" + str(gen) + "-p" + str(id) for id in ids]
xtals = current_xtals[id1:id2]
mutates = [xtal is not None for xtal in xtals]
my_args = [xtals, ids, mutates, job_tags, *args]
args_lists.append(tuple(my_args))

def process_with_timeout(results, timeout):
for result in results:
try:
res_list = result.result(timeout=timeout)
for res in res_list:
(id, xtal, match) = res
gen_results[id] = (xtal, match)
except TimeoutError:
self.logging.info("ERROR: Opt timed out after %d seconds", timeout)
except Exception as e:
self.logging.info("ERROR: An unexpected error occurred: %s", str(e))
return gen_results

def run_with_global_timeout(ncpu, args_lists, timeout, return_dict):
with ProcessPoolExecutor(max_workers=ncpu) as executor:
results = [executor.submit(optimizer_par, *p) for p in args_lists]
gen_results = process_with_timeout(results, timeout)
return_dict["gen_results"] = gen_results

# Set your global timeout value here
global_timeout = self.timeout

# Run multiprocess
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(
target=run_with_global_timeout, args=(self.ncpu, args_lists, global_timeout, return_dict)
)
p.start()
p.join(global_timeout)

if p.is_alive():
self.logging.info("ERROR: Global execution timed out after %d seconds", global_timeout)
# p.terminate()
# Ensure all child processes are terminated
child_processes = psutil.Process(p.pid).children(recursive=True)
self.logging.info("Checking child process total: %d", len(child_processes))
for proc in child_processes:
# self.logging.info("Checking child process ID: %d", pid)
try:
# proc = psutil.Process(pid)
if proc.status() == "running": # is_running():
proc.terminate()
self.logging.info("Terminate abnormal child process ID: %d", proc.pid)
except psutil.NoSuchProcess:
self.logging.info("ERROR: PID %d does not exist", proc.pid)
p.join()

gen_results = return_dict.get("gen_results", {})
# Local optimization
gen_results = self.local_optimization(gen, current_xtals, ref_pmg, ref_pxrd)

# Summary and Ranking
for id, res in enumerate(gen_results):
Expand Down Expand Up @@ -365,6 +269,7 @@ def run_with_global_timeout(ncpu, args_lists, timeout, return_dict):
print(gen_out)

# Save the reps for next move
prev_xtals = current_xtals # ; print(self.engs)
self.min_energy = np.min(np.array(self.engs))
self.N_struc = len(self.engs)

Expand Down Expand Up @@ -394,7 +299,7 @@ def _selTournament(self, fitness, factor=0.35):
individuals, *k* times. The list returned contains
references to the input *individuals*.
"""
IDs = self.random_state.choice(set(range(len(fitness))), int(len(fitness) * factor))
IDs = self.random_state.choice(len(fitness), size=int(len(fitness) * factor), replace=False)
min_fit = np.argmin(fitness[IDs])
return IDs[min_fit]

Expand Down
107 changes: 105 additions & 2 deletions pyxtal/optimize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""

from __future__ import annotations
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, TimeoutError

import logging
import os
Expand All @@ -19,9 +21,10 @@
from ost.parameters import ForceFieldParameters, compute_r2, get_lmp_efs

from pyxtal.molecule import find_rotor_from_smile, pyxtal_molecule
from pyxtal.optimize.common import optimizer, randomizer
from pyxtal.representation import representation
from pyxtal.util import new_struc
from pyxtal.optimize.common import optimizer, randomizer
from pyxtal.optimize.common import optimizer_par, optimizer_single

if TYPE_CHECKING:
from pyxtal.lattice import Lattice
Expand Down Expand Up @@ -442,7 +445,7 @@ def ff_optimization(self, xtals, N_added, N_min=50, dE=2.5, FMSE=2.5):

return N_added

def prepare_chm_info(self, params0, params1=None, suffix="calc/pyxtal"):
def prepare_chm_info(self, params0, params1=None, suffix="calc/pyxtal0"):
"""
TODO: A base classs for optimization
prepar_chm_info with the updated params.
Expand Down Expand Up @@ -649,6 +652,106 @@ def check_ref(self, reps=None, reference=None, filename="pyxtal.cif"):
print(strs)
return False

def local_optimization(self, gen, current_xtals, ref_pmg, ref_pxrd):
"""
perform optimization for each structure in the current generation
"""
args = [
self.randomizer,
self.optimizer,
self.smiles,
self.block,
self.num_block,
self.atom_info,
self.workdir + "/" + "calc",
self.sg,
self.composition,
self.lattice,
self.torsions,
self.molecules,
self.sites,
ref_pmg,
self.matcher,
ref_pxrd,
self.use_hall,
self.skip_ani,
]

gen_results = [(None, None)] * len(current_xtals)
if self.ncpu == 1:
for pop in range(len(current_xtals)):
xtal = current_xtals[pop]
job_tag = self.tag + "-g" + str(gen) + "-p" + str(pop)
mutated = xtal is not None
my_args = [xtal, pop, mutated, job_tag, *args]
gen_results[pop] = optimizer_single(*tuple(my_args))

else:
# parallel process
N_cycle = int(np.ceil(self.N_pop / self.ncpu))
args_lists = []
for i in range(self.ncpu):
id1 = i * N_cycle
id2 = min([id1 + N_cycle, len(current_xtals)])
# os.makedirs(folder, exist_ok=True)
ids = range(id1, id2)
job_tags = [self.tag + "-g" + str(gen) + "-p" + str(id) for id in ids]
xtals = current_xtals[id1:id2]
mutates = [xtal is not None for xtal in xtals]
my_args = [xtals, ids, mutates, job_tags, *args]
args_lists.append(tuple(my_args))

def process_with_timeout(results, timeout):
for result in results:
try:
res_list = result.result(timeout=timeout)
for res in res_list:
(id, xtal, match) = res
gen_results[id] = (xtal, match)
except TimeoutError:
self.logging.info("ERROR: Opt timed out after %d seconds", timeout)
except Exception as e:
self.logging.info("ERROR: An unexpected error occurred: %s", str(e))
return gen_results

def run_with_global_timeout(ncpu, args_lists, timeout, return_dict):
with ProcessPoolExecutor(max_workers=ncpu) as executor:
results = [executor.submit(optimizer_par, *p) for p in args_lists]
gen_results = process_with_timeout(results, timeout)
return_dict["gen_results"] = gen_results

# Set your global timeout value here
global_timeout = self.timeout

# Run multiprocess
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(
target=run_with_global_timeout, args=(self.ncpu, args_lists, global_timeout, return_dict)
)
p.start()
p.join(global_timeout)

if p.is_alive():
self.logging.info("ERROR: Global execution timed out after %d seconds", global_timeout)
# p.terminate()
# Ensure all child processes are terminated
child_processes = psutil.Process(p.pid).children(recursive=True)
self.logging.info("Checking child process total: %d", len(child_processes))
for proc in child_processes:
# self.logging.info("Checking child process ID: %d", pid)
try:
# proc = psutil.Process(pid)
if proc.status() == "running": # is_running():
proc.terminate()
self.logging.info("Terminate abnormal child process ID: %d", proc.pid)
except psutil.NoSuchProcess:
self.logging.info("ERROR: PID %d does not exist", proc.pid)
p.join()

gen_results = return_dict.get("gen_results", {})
return gen_results


if __name__ == "__main__":
print("test")
8 changes: 5 additions & 3 deletions pyxtal/optimize/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def mutator(xtal, smiles, opt_lat, ref_pxrd=None, dr=0.125, random_state=None):
for j in range(3, len(x[i]) - 1):
rad_num = rng.random()
if rad_num < 0.25:
x[i][j] += choice([45.0, 90.0])
x[i][j] += rng.choice([45.0, 90.0])
elif rad_num < 0.5:
x[i][j] *= -1
try:
Expand Down Expand Up @@ -98,8 +98,9 @@ def randomizer(
Returns:
PyXtal object
"""
rng = np.random.default_rng(random_state)
mols = [smi + ".smi" for smi in smiles] if molecules is None else [choice(m) for m in molecules]
sg = choice(sgs)
sg = rng.choice(sgs)
wp = Group(sg, use_hall=True)[0] if use_hall else Group(sg)[0]
mult = len(wp)
numIons = [int(c * mult) for c in comp]
Expand All @@ -118,7 +119,7 @@ def randomizer(
else:
perm = sg > 15
# For specical setting, we only do standard_setting
hn = Hall(sg).hall_default if min(comp) < 1 else choice(Hall(sg, permutation=perm).hall_numbers)
hn = Hall(sg).hall_default if min(comp) < 1 else rng.choice(Hall(sg, permutation=perm).hall_numbers)
xtal.from_random(
3,
hn,
Expand All @@ -132,6 +133,7 @@ def randomizer(
torsions=torsions,
sites=sites,
use_hall=True,
#random_state=random_state,
)
if xtal.valid:
break
Expand Down

0 comments on commit c5e3111

Please sign in to comment.