diff --git a/flare/otf.py b/flare/otf.py index 9b71e25c3..019236ae9 100644 --- a/flare/otf.py +++ b/flare/otf.py @@ -192,17 +192,19 @@ def __init__( self.last_dft_step = 0 # Set the prediction function based on user inputs. + self.calculate_energy = calculate_energy + self.calculate_efs = calculate_efs # Force only prediction. if (n_cpus > 1 and gp.per_atom_par and gp.parallel) and not ( - calculate_energy or calculate_efs + self.calculate_energy or self.calculate_efs ): self.pred_func = predict.predict_on_structure_par - elif not (calculate_energy or calculate_efs): + elif not (self.calculate_energy or self.calculate_efs): self.pred_func = predict.predict_on_structure # Energy and force prediction. - elif (n_cpus > 1 and gp.per_atom_par and gp.parallel) and not (calculate_efs): + elif (n_cpus > 1 and gp.per_atom_par and gp.parallel) and not (self.calculate_efs): self.pred_func = predict.predict_on_structure_par_en - elif not calculate_efs: + elif not self.calculate_efs: self.pred_func = predict.predict_on_structure_en # Energy, force, and stress prediction. elif n_cpus > 1 and gp.per_atom_par and gp.parallel: