From 2167d8937e0aa4ea3feb0162e9fa7c86614f6e64 Mon Sep 17 00:00:00 2001 From: qzhu2017 Date: Fri, 6 Sep 2024 14:47:32 -0400 Subject: [PATCH] fix #271 by adding subprocess control and output logger for mpi debug --- pyxtal/db.py | 6 +-- pyxtal/interface/charmm.py | 78 ++++++++++++++++++++++++++------------ pyxtal/optimize/WFS.py | 15 +++++--- pyxtal/optimize/base.py | 26 ++++++++----- pyxtal/optimize/common.py | 7 +--- 5 files changed, 86 insertions(+), 46 deletions(-) diff --git a/pyxtal/db.py b/pyxtal/db.py index fc9d0ae1..6f8262eb 100644 --- a/pyxtal/db.py +++ b/pyxtal/db.py @@ -379,7 +379,7 @@ def copy(self, db_name, csd_codes): """ if db_name == self.db_name: raise RuntimeError("Cannot use the same db file for copy") - with connect(db_name) as db: + with connect(db_name, serial=True) as db: for csd_code in csd_codes: row_info = self.get_row_info(code=csd_code) (atom, kvp, data) = row_info @@ -603,7 +603,7 @@ def add_strucs_from_db(self, db_file, check=False, tol=0.1, freq=50): print(f"\nAdding new strucs from {db_file:s}") count = 0 - with connect(db_file) as db: + with connect(db_file, serial=True) as db: for row in db.select(): atoms = row.toatoms() xtal = pyxtal() @@ -1448,7 +1448,7 @@ def get_db_unique(self, db_name=None, prec=3): unique_props[prop_key] = (row.id, dof) ids = [unique_props[key][0] for key in unique_props.keys()] - with connect(db_name) as db: + with connect(db_name, serial=True) as db: for id in ids: row = self.db.get(id) kvp = {} diff --git a/pyxtal/interface/charmm.py b/pyxtal/interface/charmm.py index d416f0f2..c4fd7f7e 100644 --- a/pyxtal/interface/charmm.py +++ b/pyxtal/interface/charmm.py @@ -1,7 +1,7 @@ import contextlib import os import shutil - +import subprocess import numpy as np @@ -36,11 +36,14 @@ def __init__( output="charmm.log", dump="result.pdb", debug=False, + timeout=20, ): + self.errorE = 1e+5 + self.error = False if steps is None: steps = [2000, 1000] self.debug = debug - + self.timeout = timeout # check charmm Executable if shutil.which(exe) is None: raise BaseException(f"{exe} is not installed") @@ -101,16 +104,32 @@ def run(self, clean=True): os.chdir(self.folder) self.write() # ; print("write", time()-t0) - self.execute() # ; print("exe", time()-t0) - self.read() # ; print("read", self.structure.energy) + res = self.execute() # ; print("exe", time()-t0) + if res is not None: + self.read() # ; print("read", self.structure.energy) + else: + self.structure.energy = self.errorE + self.error = True if clean: self.clean() os.chdir(cwd) def execute(self): - cmd = self.exe + "<" + self.input + ">" + self.output - os.system(cmd) + cmd = self.exe + " < " + self.input + " > " + self.output + # os.system(cmd) + with open(os.devnull, 'w') as devnull: + try: + # Run the external command with a timeout + result = subprocess.run( + cmd, shell=True, timeout=self.timeout, check=True, stderr=devnull) + return result.returncode # Or handle the result as needed + except subprocess.CalledProcessError as e: + print(f"Command '{cmd}' failed with return code {e.returncode}.") + return None + except subprocess.TimeoutExpired: + print(f"External command {cmd} timed out.") + return None def clean(self): os.remove(self.input) @@ -129,7 +148,8 @@ def write(self): a, b, c, alpha, beta, gamma = lat.get_para(degree=True) ltype = lat.ltype - if ltype in ['trigonal', 'Trigonal']: ltype = 'hexagonal' + if ltype in ['trigonal', 'Trigonal']: + ltype = 'hexagonal' fft = self.FFTGrid(np.array([a, b, c])) @@ -148,14 +168,16 @@ def write(self): if self.atom_info is None: f.write(f"U0{site.type:d} ") else: - f.write("{:s} ".format(self.atom_info["resName"][site.type])) + f.write("{:s} ".format( + self.atom_info["resName"][site.type])) f.write("\ngenerate main first none last none setup warn\n") f.write("Read coor card free\n") f.write("* Residues coordinate\n*\n") f.write(f"{sum(atom_count):5d}\n") for i, site in enumerate(self.structure.mol_sites): - res_name = f"U0{site.type:d}" if self.atom_info is None else self.atom_info["resName"][site.type] + res_name = f"U0{site.type:d}" if self.atom_info is None else self.atom_info[ + "resName"][site.type] # reset lattice if needed (to move out later) site.lattice = lat @@ -179,12 +201,13 @@ def write(self): j + 1 + count, i + 1, res_name, label, *coord ) ) - # quickly check if + # quickly check if if abs(coord).max() > 500.0: print("Unexpectedly large input coordinates, stop and debug") print(self.structure) self.structure.to_file('bug.cif') - import sys; sys.exit() + import sys + sys.exit() f.write(f"write psf card name {self.psf:s}\n") f.write(f"write coor crd card name {self.crd:s}\n") @@ -204,26 +227,31 @@ def write(self): f.write("coor stat select all end\n") f.write("Crystal Define @shape @a @b @c @alpha @beta @gamma\n") site0 = self.structure.mol_sites[0] - f.write(f"Crystal Build cutoff 14.0 noperations {len(site0.wp.ops) - 1:d}\n") + f.write( + f"Crystal Build cutoff 14.0 noperations {len(site0.wp.ops) - 1:d}\n") for i, op in enumerate(site0.wp.ops): if i > 0: f.write(f"({op.as_xyz_str():s})\n") - f.write("image byres xcen ?xave ycen ?yave zcen ?zave sele resn LIG end\n") + f.write( + "image byres xcen ?xave ycen ?yave zcen ?zave sele resn LIG end\n") f.write("set 7 fswitch\n") f.write("set 8 atom\n") f.write("set 9 vatom\n") f.write("Update inbfrq 10 imgfrq 10 ihbfrq 10 -\n") - f.write("ewald pmewald lrc fftx {:d} ffty {:d} fftz {:d} -\n".format(*fft)) + f.write( + "ewald pmewald lrc fftx {:d} ffty {:d} fftz {:d} -\n".format(*fft)) f.write("kappa 0.34 order 6 CTOFNB 12.0 CUTNB 14.0 QCOR 1.0 -\n") f.write("@7 @8 @9 vfswitch !\n") f.write(f"mini {self.algo:s} nstep {self.steps[0]:d}\n") if len(self.steps) > 1: - f.write(f"mini {self.algo:s} lattice nstep {self.steps[1]:d} \n") + f.write( + f"mini {self.algo:s} lattice nstep {self.steps[1]:d} \n") if len(self.steps) > 2: f.write(f"mini {self.algo:s} nstep {self.steps[2]:d}\n") - f.write("coor conv SYMM FRAC ?xtla ?xtlb ?xtlc ?xtlalpha ?xtlbeta ?xtlgamma\n") # + f.write( + "coor conv SYMM FRAC ?xtla ?xtlb ?xtlc ?xtlalpha ?xtlbeta ?xtlgamma\n") # f.write(f"\nwrite coor pdb name {self.dump:s}\n") # f.write("*CELL : ?xtla ?xtlb ?xtlc ?xtlalpha ?xtlbeta ?xtlgamma\n") # f.write(f"*Z = {len(site0.wp):d}\n") @@ -270,7 +298,8 @@ def read(self): XYZ = [float(x) for x in xyz] positions.append(XYZ) except: - pass # print("Warning: BAD charmm output: " + line) + # print("Warning: BAD charmm output: " + line) + pass positions = np.array(positions) self.structure.energy *= Z @@ -283,7 +312,7 @@ def read(self): # if True: try: for _i, site in enumerate(self.structure.mol_sites): - coords = positions[count : count + len(site.molecule.mol)] + coords = positions[count: count + len(site.molecule.mol)] site.update(coords, self.structure.lattice) count += len(site.molecule.mol) # print("after relaxation : ", self.structure.lattice, "iter: ", self.structure.iter) @@ -292,7 +321,8 @@ def read(self): # print("after latticeopt : ", self.structure.lattice, self.structure.check_distance()); import sys; sys.exit() except: # molecular connectivity or lattice optimization - self.structure.energy = 10000 + self.structure.energy = self.errorE + self.error = True if self.debug: print("Unable to retrieve Structure after optimization") print("lattice", self.structure.lattice) @@ -304,12 +334,11 @@ def read(self): print("short distance pair", pairs) else: - self.structure.energy = 10000 + self.structure.energy = self.errorE + self.error = True if self.debug: print(self.structure) - import sys - - sys.exit() + import sys; sys.exit() def FFTGrid(self, ABC): """ @@ -616,7 +645,8 @@ def merge(self, rtf1=None, single=None): for a in res["ANGL"]: tmp = a.split("!") tmp1 = tmp[0].split() - a1, a2, a3 = str(i) + tmp1[1], str(i) + tmp1[2], str(i) + tmp1[3] + a1, a2, a3 = str( + i) + tmp1[1], str(i) + tmp1[2], str(i) + tmp1[3] a = f"ANGL {a1:6s} {a2:6s} {a3:6s} " if len(tmp) > 1: a += f"!{tmp[-1]:12s}" diff --git a/pyxtal/optimize/WFS.py b/pyxtal/optimize/WFS.py index f836e44a..ede4a794 100644 --- a/pyxtal/optimize/WFS.py +++ b/pyxtal/optimize/WFS.py @@ -147,6 +147,7 @@ def __init__( strs = self.full_str() self.logging.info(strs) print(strs) + print(f"Rank {self.rank} finish initialization {self.tag}") def full_str(self): s = str(self) @@ -168,17 +169,15 @@ def _run(self, pool=None): # Related to the FF optimization N_added = 0 success_rate = 0 + print(f"Rank {self.rank} starts WFS in {self.tag}") for gen in range(self.N_gen): self.generation = gen cur_xtals = None - print(f"Rank {self.rank} entering generation {gen} in {self.tag}") - + self.logging.info(f"Gen {gen} starts in Rank {self.rank} {self.tag}") if self.rank == 0: print(f"\nGeneration {gen:d} starts") - self.logging.info(f"Generation {gen:d} starts") t0 = time() - # Initialize cur_xtals = [(None, "Random")] * self.N_pop @@ -194,10 +193,11 @@ def _run(self, pool=None): # broadcast if self.use_mpi: cur_xtals = self.comm.bcast(cur_xtals, root=0) - #print(f"Rank {self.rank} after broadcast: current_xtals = {current_xtals}") + #self.logging.info(f"Rank {self.rank} gets {len(cur_xtals)} strucs {self.tag}") # Local optimization gen_results = self.local_optimization(cur_xtals, pool=pool) + self.logging.info(f"Rank {self.rank} finishes local_opt {self.tag}") prev_xtals = None if self.rank == 0: @@ -212,6 +212,8 @@ def _run(self, pool=None): if self.use_mpi: prev_xtals = self.comm.bcast(prev_xtals, root=0) + self.logging.info(f"Gen {gen} bcast in Rank {self.rank} {self.tag}") + # Update the FF parameters if necessary if self.ff_opt: N_added = self.update_ff_paramters(cur_xtals, engs, N_added) @@ -225,13 +227,16 @@ def _run(self, pool=None): elif self.ref_pxrd is not None: self.count_pxrd_match(cur_xtals, matches) + # quit the loop if self.use_mpi: quit = self.comm.bcast(quit, root=0) self.comm.Barrier() + self.logging.info(f"Gen {gen} Finish in Rank {self.rank} {self.tag}") # Ensure that all ranks exit if quit: + self.logging.info(f"Early Termination in Rank {self.rank} {self.tag}") return success_rate return success_rate diff --git a/pyxtal/optimize/base.py b/pyxtal/optimize/base.py index 73a5c97e..8aa3ded4 100644 --- a/pyxtal/optimize/base.py +++ b/pyxtal/optimize/base.py @@ -6,9 +6,9 @@ - QRS """ from __future__ import annotations -import multiprocessing from multiprocessing import Pool -from concurrent.futures import ProcessPoolExecutor, TimeoutError +from concurrent.futures import TimeoutError +import signal import logging import os @@ -26,9 +26,6 @@ from pyxtal.optimize.common import optimizer_par, optimizer_single from pyxtal.lattice import Lattice from pyxtal.symmetry import Group -import signal -import gc - def run_optimizer_with_timeout(args): """ @@ -174,6 +171,8 @@ def __init__( # Generation and Optimization self.workdir = workdir self.log_file = self.workdir + "/loginfo" + if self.rank > 0: self.log_file += f"-{self.rank}" + self.skip_ani = skip_ani # self.randomizer = randomizer # self.optimizer = optimizer @@ -359,6 +358,9 @@ def run(self, ref_pmg=None, ref_pxrd=None): strs = f"{self.name:s} {self.workdir} COMPLETED " strs += f"in {t:.1f} mins {self.N_struc:d} strucs." print(strs) + + if self.use_mpi: + self.comm.Barrier() return results def select_xtals(self, ref_xtals, ids, N_max): @@ -961,6 +963,7 @@ def local_optimization(self, xtals, qrs=False, pool=None): elif self.ncpu == 1: return self.local_optimization_serial(xtals, qrs) else: + print(f"Local optimization by multi-threads {ncpu}") return self.local_optimization_mproc(xtals, self.ncpu, qrs=qrs, pool=pool) def local_optimization_serial(self, xtals, qrs=False): @@ -997,21 +1000,24 @@ def local_optimization_mpi(self, xtals, qrs, pool): # Distribute args_lists across available ranks (processes) local_xtals = xtals[self.rank::self.size] + local_ids = list(range(self.N_pop))[self.rank::self.size] - # Determine the number of cores available on the node + # Call local_optimization_mproc + self.logging.info(f"Rank {self.rank} gets {len(local_xtals)} strucs") results = self.local_optimization_mproc(local_xtals, self.ncpu, local_ids, qrs, pool) # Synchronize before gathering + self.logging.info(f"Rank {self.rank} finish local_optimization_mproc") self.comm.Barrier() # Gather all results at the root process - #print(f"Rank {self.rank} in MPI_Gather at gen {gen} {time()-t0}") + self.logging.info(f"Rank {self.rank} in MPI_Gather at gen {gen}") all_results = self.comm.gather(results, root=0) - #print(f"Rank {self.rank} done MPI_Gather at gen {gen} {time()-t0}") + self.logging.info(f"Rank {self.rank} done MPI_Gather at gen {gen}") # If root process, process the results gen_results = None @@ -1023,6 +1029,7 @@ def local_optimization_mpi(self, xtals, qrs, pool): gen_results[id] = (id, xtal, match) # Broadcast + self.logging.info(f"Rank {self.rank} MPI_bcast at gen {gen}") gen_results = self.comm.bcast(gen_results, root=0) return gen_results @@ -1037,7 +1044,6 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None): ids (list): qrs (bool): Force mutation or not (related to QRS) """ - print("Local optimization enabled by multi-threads", ncpu) gen = self.generation t0 = time() args = self._get_local_optimization_args() @@ -1047,6 +1053,7 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None): N_cycle = int(np.ceil(len(xtals) / ncpu)) args_lists = [] + # Assign args for i in range(ncpu): id1 = i * N_cycle id2 = min([id1 + N_cycle, len(xtals)]) @@ -1059,6 +1066,7 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None): my_args = [_xtals, _ids, mutates, job_tags, *args, self.timeout] args_lists.append(tuple(my_args)) + self.logging.info(f"Rank {self.rank} assign args in local_opt_mproc") gen_results = [] for result in pool.imap_unordered(process_task, args_lists): if result is not None: diff --git a/pyxtal/optimize/common.py b/pyxtal/optimize/common.py index 01250612..9c210fdd 100644 --- a/pyxtal/optimize/common.py +++ b/pyxtal/optimize/common.py @@ -241,7 +241,7 @@ def optimizer( calc.run() # print("Debug", calc.optimized); import sys; sys.exit() # only count good struc - if calc.structure.energy < 9999: + if not calc.error: calc = CHARMM(calc.structure, tag, steps=steps, atom_info=atom_info) calc.run()#clean=False) @@ -623,12 +623,9 @@ def load_reference_from_db(db_name, code=None): args.append((smile, wdir, sg, tag, chm_info, comp, lat, pmg0, wt, spg, N_torsion)) return args - - - if __name__ == "__main__": + import pymatgen.analysis.structure_matcher as sm - from pyxtal.db import database w_dir = "tmp"