From 8544b9fd1f11eb1351fd29fd47457caa9d513ef4 Mon Sep 17 00:00:00 2001 From: Jonathan Vandermause Date: Tue, 16 Jun 2020 16:28:06 -0400 Subject: [PATCH] fix otf logic --- flare/otf.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/flare/otf.py b/flare/otf.py index e4d6035c1..3649cc739 100644 --- a/flare/otf.py +++ b/flare/otf.py @@ -93,7 +93,7 @@ def __init__(self, prev_pos_init: 'ndarray' = None, rescale_steps: List[int] = [], rescale_temps: List[int] = [], # flare args - gp: gp.GaussianProcess=None, + gp: gp.GaussianProcess = None, calculate_energy: bool = False, write_model: int = 0, # otf args @@ -103,8 +103,8 @@ def __init__(self, max_atoms_added: int = 1, freeze_hyps: int = 10, # dft args force_source: str = "qe", - npool: int = None, mpi: str = "srun", dft_loc: str=None, - dft_input: str=None, dft_output='dft.out', dft_kwargs=None, + npool: int = None, mpi: str = "srun", dft_loc: str = None, + dft_input: str = None, dft_output='dft.out', dft_kwargs=None, store_dft_output: Tuple[Union[str, List[str]],str] = None, # par args n_cpus: int = 1, @@ -157,11 +157,11 @@ def __init__(self, self.dft_count = 0 # set pred function - if (n_cpus>1 and gp.per_atom_par and gp.parallel) and not calculate_energy: + if (n_cpus > 1 and gp.per_atom_par and gp.parallel) and not calculate_energy: self.pred_func = predict.predict_on_structure_par elif not calculate_energy: self.pred_func = predict.predict_on_structure - elif (n_cpus>1 and gp.per_atom_par and gp.parallel): + elif (n_cpus > 1 and gp.per_atom_par and gp.parallel): self.pred_func = predict.predict_on_structure_par_en else: self.pred_func = predict.predict_on_structure_en @@ -201,7 +201,8 @@ def run(self): while self.curr_step < self.number_of_steps: # run DFT and train initial model if first step and DFT is on - if self.curr_step == 0 and self.std_tolerance != 0 and len(self.gp.training_data)==0: + if self.curr_step == 0 and self.std_tolerance != 0 and \ + len(self.gp.training_data) == 0: self.initialize_train() new_pos = self.md_step() @@ -210,10 +211,10 @@ def run(self): # after step 1, try predicting with GP model else: - # compute forces and stds with GP self.dft_step = False self.compute_properties() + new_pos = self.md_step() # get max uncertainty atoms std_in_bound, target_atoms = \ @@ -221,10 +222,7 @@ def run(self): self.gp.hyps[-1], self.structure, self.max_atoms_added) - if std_in_bound: - new_pos = self.md_step() - - else: + if not std_in_bound: # record GP forces self.update_temperature(new_pos) self.record_state() @@ -261,7 +259,6 @@ def run(self): if self.write_model >= 1: self.gp.write_model(self.output_name+"_model") - def initialize_train(self): # call dft and update positions self.run_dft() @@ -270,7 +267,6 @@ def initialize_train(self): # make initial gp model and predict forces self.update_gp(self.init_atoms, dft_frcs) - def compute_properties(self): ''' In ASE-OTF, it will be replaced by subclass method @@ -278,7 +274,6 @@ def compute_properties(self): self.gp.check_L_alpha() self.pred_func(self.structure, self.gp, self.n_cpus) - def md_step(self): ''' In ASE-OTF, it will be replaced by subclass method @@ -286,7 +281,6 @@ def md_step(self): return md.update_positions(self.dt, self.noa, self.structure) - def run_dft(self): """Calculates DFT forces on atoms in the current structure.