diff --git a/amptorch/tests/consistency_test.py b/amptorch/tests/consistency_test.py
index 2d704914..8120ac23 100644
--- a/amptorch/tests/consistency_test.py
+++ b/amptorch/tests/consistency_test.py
@@ -133,7 +133,7 @@ def test_energy_force_consistency():
             "scaling": {"type": "normalize", "range": (-1, 1)},
         },
         "cmd": {
-            "debug": False,
+            "debug": True,
             "run_dir": "./",
             "seed": 1,
             "identifier": "test",
diff --git a/amptorch/trainer.py b/amptorch/trainer.py
index be229422..6cc77705 100644
--- a/amptorch/trainer.py
+++ b/amptorch/trainer.py
@@ -118,6 +118,7 @@ def load_dataset(self):
         self.target_scaler = self.train_dataset.target_scaler
         self.input_dim = self.train_dataset.input_dim
         self.val_split = self.config["dataset"].get("val_split", 0)
+        self.config["dataset"]["descriptor"] = descriptor_setup
         if not self.debug:
             normalizers = {
                 "target": self.target_scaler,
@@ -125,7 +126,6 @@ def load_dataset(self):
             }
             torch.save(normalizers, os.path.join(self.cp_dir, "normalizers.pt"))
             # clean/organize config
-            self.config["dataset"]["descriptor"] = descriptor_setup
             self.config["dataset"]["fp_length"] = self.input_dim
             torch.save(self.config, os.path.join(self.cp_dir, "config.pt"))
         print("Loading dataset: {} images".format(len(self.train_dataset)))
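
For context: the `trainer.py` hunks move the descriptor bookkeeping ahead of the `if not self.debug:` guard, so `self.config["dataset"]["descriptor"]` is populated even when no checkpoint files are written, which the updated test exercises by running with `"debug": True`. The sketch below is a minimal, self-contained illustration of that ordering only; the `Trainer` class, its constructor arguments, the `load_dataset` parameters, and the example descriptor values are simplified stand-ins, not amptorch's actual API.

```python
# Minimal sketch of the reordering made in load_dataset(); class and
# signatures here are hypothetical stand-ins for amptorch's trainer.
import os

import torch


class Trainer:
    def __init__(self, config, debug=False, cp_dir="./checkpoints"):
        self.config = config
        self.debug = debug
        self.cp_dir = cp_dir

    def load_dataset(self, descriptor_setup, target_scaler, feature_scaler, input_dim):
        # Record the descriptor setup unconditionally, so the in-memory config
        # is complete even when debug mode skips writing checkpoint files.
        self.config["dataset"]["descriptor"] = descriptor_setup
        if not self.debug:
            os.makedirs(self.cp_dir, exist_ok=True)
            normalizers = {"target": target_scaler, "feature": feature_scaler}
            torch.save(normalizers, os.path.join(self.cp_dir, "normalizers.pt"))
            self.config["dataset"]["fp_length"] = input_dim
            torch.save(self.config, os.path.join(self.cp_dir, "config.pt"))


if __name__ == "__main__":
    # With debug=True (as in the updated test) nothing is written to disk,
    # but the descriptor setup is still available on the config afterwards.
    trainer = Trainer({"dataset": {}}, debug=True)
    trainer.load_dataset(("gaussian", {"cutoff": 6.0}), None, None, 36)
    assert trainer.config["dataset"]["descriptor"] == ("gaussian", {"cutoff": 6.0})
```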