Skip to content

Commit

Permalink
Merge pull request #71 from mir-group/bug/jon/gp-update
Browse files Browse the repository at this point in the history
Bug/jon/gp update
  • Loading branch information
jonpvandermause authored Sep 25, 2019
2 parents a0b6ba8 + 059ec02 commit cc68972
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
7 changes: 5 additions & 2 deletions flare/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ def update_db(self, struc: Structure, forces: list,
# create numpy array of training labels
self.training_labels_np = self.force_list_to_np(self.training_labels)

self.set_L_alpha()

def add_one_env(self, env: AtomicEnvironment,
force: np.array, train: bool = False, **kwargs):
"""
Expand Down Expand Up @@ -176,6 +174,11 @@ def predict(self, x_t: AtomicEnvironment, d: int) -> [float, float]:
# Kernel vector allows for evaluation of At. Env.
k_v = self.get_kernel_vector(x_t, d)

# Guarantee that alpha is up to date with training set
if self.alpha is None or 3 * len(self.training_data) != len(
self.alpha):
self.set_L_alpha()

# get predictive mean
pred_mean = np.matmul(k_v, self.alpha)

Expand Down
4 changes: 3 additions & 1 deletion flare/gp_from_aimd.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def pre_run(self):
train_atoms.append(atom)

self.update_gp_and_print(frame, train_atoms, train=False)
self.gp.set_L_alpha()

# These conditions correspond to if either the GP was never trained
# or if data was added to it during the pre-run.
Expand Down Expand Up @@ -207,6 +208,8 @@ def run(self):

if self.train_count < self.max_trains:
self.train_gp()
else:
self.gp.set_L_alpha()

self.output.conclude_run()

Expand Down Expand Up @@ -234,7 +237,6 @@ def update_gp_and_print(self, frame: Structure, train_atoms: List[int],

# update gp model
self.gp.update_db(frame, frame.forces, custom_range=train_atoms)
self.gp.set_L_alpha()

if train:
self.train_gp()
Expand Down
5 changes: 3 additions & 2 deletions flare/otf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, filename, calculate_energy=False):

def make_gp(self, cell=None, kernel=None, kernel_grad=None, algo=None,
call_no=None, cutoffs=None, hyps=None, init_gp=None,
energy_force_kernel=None, hyp_no=None):
energy_force_kernel=None, hyp_no=None, par=True):

if init_gp is None:
# Use run's values as extracted from header
Expand All @@ -71,7 +71,8 @@ def make_gp(self, cell=None, kernel=None, kernel_grad=None, algo=None,
gp_model = \
gp.GaussianProcess(kernel, kernel_grad, gp_hyps,
cutoffs, opt_algorithm=algo,
energy_force_kernel=energy_force_kernel)
energy_force_kernel=energy_force_kernel,
par=par)
else:
gp_model = init_gp
call_no = len(self.gp_position_list)
Expand Down

0 comments on commit cc68972

Please sign in to comment.