Skip to content

Commit

Permalink
fix otf logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpvandermause committed Jun 16, 2020
1 parent 43dc506 commit 8544b9f
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions flare/otf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self,
prev_pos_init: 'ndarray' = None,
rescale_steps: List[int] = [], rescale_temps: List[int] = [],
# flare args
gp: gp.GaussianProcess=None,
gp: gp.GaussianProcess = None,
calculate_energy: bool = False,
write_model: int = 0,
# otf args
Expand All @@ -103,8 +103,8 @@ def __init__(self,
max_atoms_added: int = 1, freeze_hyps: int = 10,
# dft args
force_source: str = "qe",
npool: int = None, mpi: str = "srun", dft_loc: str=None,
dft_input: str=None, dft_output='dft.out', dft_kwargs=None,
npool: int = None, mpi: str = "srun", dft_loc: str = None,
dft_input: str = None, dft_output='dft.out', dft_kwargs=None,
store_dft_output: Tuple[Union[str, List[str]],str] = None,
# par args
n_cpus: int = 1,
Expand Down Expand Up @@ -157,11 +157,11 @@ def __init__(self,
self.dft_count = 0

# set pred function
if (n_cpus>1 and gp.per_atom_par and gp.parallel) and not calculate_energy:
if (n_cpus > 1 and gp.per_atom_par and gp.parallel) and not calculate_energy:
self.pred_func = predict.predict_on_structure_par
elif not calculate_energy:
self.pred_func = predict.predict_on_structure
elif (n_cpus>1 and gp.per_atom_par and gp.parallel):
elif (n_cpus > 1 and gp.per_atom_par and gp.parallel):
self.pred_func = predict.predict_on_structure_par_en
else:
self.pred_func = predict.predict_on_structure_en
Expand Down Expand Up @@ -201,7 +201,8 @@ def run(self):

while self.curr_step < self.number_of_steps:
# run DFT and train initial model if first step and DFT is on
if self.curr_step == 0 and self.std_tolerance != 0 and len(self.gp.training_data)==0:
if self.curr_step == 0 and self.std_tolerance != 0 and \
len(self.gp.training_data) == 0:

self.initialize_train()
new_pos = self.md_step()
Expand All @@ -210,21 +211,18 @@ def run(self):

# after step 1, try predicting with GP model
else:

# compute forces and stds with GP
self.dft_step = False
self.compute_properties()
new_pos = self.md_step()

# get max uncertainty atoms
std_in_bound, target_atoms = \
is_std_in_bound(self.std_tolerance,
self.gp.hyps[-1], self.structure,
self.max_atoms_added)

if std_in_bound:
new_pos = self.md_step()

else:
if not std_in_bound:
# record GP forces
self.update_temperature(new_pos)
self.record_state()
Expand Down Expand Up @@ -261,7 +259,6 @@ def run(self):
if self.write_model >= 1:
self.gp.write_model(self.output_name+"_model")


def initialize_train(self):
# call dft and update positions
self.run_dft()
Expand All @@ -270,23 +267,20 @@ def initialize_train(self):
# make initial gp model and predict forces
self.update_gp(self.init_atoms, dft_frcs)


def compute_properties(self):
'''
In ASE-OTF, it will be replaced by subclass method
'''
self.gp.check_L_alpha()
self.pred_func(self.structure, self.gp, self.n_cpus)


def md_step(self):
'''
In ASE-OTF, it will be replaced by subclass method
'''
return md.update_positions(self.dt, self.noa,
self.structure)


def run_dft(self):
"""Calculates DFT forces on atoms in the current structure.
Expand Down

0 comments on commit 8544b9f

Please sign in to comment.