From 130186d782e2fb3e1f5d4877497d82101a7a9913 Mon Sep 17 00:00:00 2001 From: Yu Xie Date: Mon, 1 Aug 2022 09:46:08 -0400 Subject: [PATCH] change Step to step in wandb --- flare/learners/otf.py | 64 ++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/flare/learners/otf.py b/flare/learners/otf.py index a90ffc3d2..22afacc2c 100644 --- a/flare/learners/otf.py +++ b/flare/learners/otf.py @@ -409,15 +409,17 @@ def run(self): # wandb log mae if self.wandb_log is not None: - wandb.log({ - "Step": self.curr_step, - "e_mae": e_mae, - "e_mav": e_mav, - "f_mae": f_mae, - "f_mav": f_mav, - "s_mae": s_mae, - "s_mav": s_mav, - }) + wandb.log( + { + "dft_e_mae": e_mae, + "dft_e_mav": e_mav, + "dft_f_mae": f_mae, + "dft_f_mav": f_mav, + "dft_s_mae": s_mae, + "dft_s_mav": s_mav, + }, + step = self.curr_step, + ) # write gp forces if counter >= self.skip and not self.dft_step: @@ -772,17 +774,21 @@ def record_state(self): # wandb log mae if self.wandb_log is not None: - wandb.log({ - "Step": self.curr_step, - "temperature": self.temperature, - "ke": self.KE, - "pe": self.atoms.get_potential_energy(), - }) + wandb.log( + { + "temperature": self.temperature, + "ke": self.KE, + "pe": self.atoms.get_potential_energy(), + }, + step = self.curr_step, + ) if "stds" in self.atoms.calc.results: - wandb.log({ - "Step": self.curr_step, - "maxunc": np.max(np.abs(self.atoms.calc.results["stds"])), - }) + wandb.log( + { + "maxunc": np.max(np.abs(self.atoms.calc.results["stds"])), + }, + step = self.curr_step, + ) if self.md_engine == "Fake" and not self.dft_step: tic = time.time() @@ -801,15 +807,17 @@ def record_state(self): # wandb log mae if self.wandb_log is not None: - wandb.log({ - "Step": self.curr_step, - "e_mae": e_mae, - "e_mav": e_mav, - "f_mae": f_mae, - "f_mav": f_mav, - "s_mae": s_mae, - "s_mav": s_mav, - }) + wandb.log( + { + "e_mae": e_mae, + "e_mav": e_mav, + "f_mae": f_mae, + "f_mav": f_mav, + "s_mae": s_mae, + "s_mav": s_mav, + }, + step = self.curr_step, + ) def record_dft_data(self, structure, target_atoms):