diff --git a/flare/gp.py b/flare/gp.py index 6b3e3b560..0e5d9ae15 100644 --- a/flare/gp.py +++ b/flare/gp.py @@ -74,8 +74,6 @@ def update_db(self, struc: Structure, forces: list, # create numpy array of training labels self.training_labels_np = self.force_list_to_np(self.training_labels) - self.set_L_alpha() - def add_one_env(self, env: AtomicEnvironment, force: np.array, train: bool = False, **kwargs): """ @@ -176,6 +174,11 @@ def predict(self, x_t: AtomicEnvironment, d: int) -> [float, float]: # Kernel vector allows for evaluation of At. Env. k_v = self.get_kernel_vector(x_t, d) + # Guarantee that alpha is up to date with training set + if self.alpha is None or 3 * len(self.training_data) != len( + self.alpha): + self.set_L_alpha() + # get predictive mean pred_mean = np.matmul(k_v, self.alpha) diff --git a/flare/gp_from_aimd.py b/flare/gp_from_aimd.py index 6f2d6ff4e..d42f072da 100644 --- a/flare/gp_from_aimd.py +++ b/flare/gp_from_aimd.py @@ -156,6 +156,7 @@ def pre_run(self): train_atoms.append(atom) self.update_gp_and_print(frame, train_atoms, train=False) + self.gp.set_L_alpha() # These conditions correspond to if either the GP was never trained # or if data was added to it during the pre-run. @@ -207,6 +208,8 @@ def run(self): if self.train_count < self.max_trains: self.train_gp() + else: + self.gp.set_L_alpha() self.output.conclude_run() @@ -234,7 +237,6 @@ def update_gp_and_print(self, frame: Structure, train_atoms: List[int], # update gp model self.gp.update_db(frame, frame.forces, custom_range=train_atoms) - self.gp.set_L_alpha() if train: self.train_gp() diff --git a/flare/otf_parser.py b/flare/otf_parser.py index 0bfda852e..44c9d816c 100644 --- a/flare/otf_parser.py +++ b/flare/otf_parser.py @@ -44,7 +44,7 @@ def __init__(self, filename, calculate_energy=False): def make_gp(self, cell=None, kernel=None, kernel_grad=None, algo=None, call_no=None, cutoffs=None, hyps=None, init_gp=None, - energy_force_kernel=None, hyp_no=None): + energy_force_kernel=None, hyp_no=None, par=True): if init_gp is None: # Use run's values as extracted from header @@ -71,7 +71,8 @@ def make_gp(self, cell=None, kernel=None, kernel_grad=None, algo=None, gp_model = \ gp.GaussianProcess(kernel, kernel_grad, gp_hyps, cutoffs, opt_algorithm=algo, - energy_force_kernel=energy_force_kernel) + energy_force_kernel=energy_force_kernel, + par=par) else: gp_model = init_gp call_no = len(self.gp_position_list)