From 173b966fb28feac149cb189d5304ffe83c17f2c3 Mon Sep 17 00:00:00 2001
From: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com>
Date: Sun, 23 May 2021 14:57:52 -0500
Subject: [PATCH] debug-flag fix (#97)

* debug-flag fix

* move debug test to cons test
---
 amptorch/tests/consistency_test.py | 2 +-
 amptorch/trainer.py                | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/amptorch/tests/consistency_test.py b/amptorch/tests/consistency_test.py
index 2d704914..8120ac23 100644
--- a/amptorch/tests/consistency_test.py
+++ b/amptorch/tests/consistency_test.py
@@ -133,7 +133,7 @@ def test_energy_force_consistency():
             "scaling": {"type": "normalize", "range": (-1, 1)},
         },
         "cmd": {
-            "debug": False,
+            "debug": True,
             "run_dir": "./",
             "seed": 1,
             "identifier": "test",
diff --git a/amptorch/trainer.py b/amptorch/trainer.py
index be229422..6cc77705 100644
--- a/amptorch/trainer.py
+++ b/amptorch/trainer.py
@@ -118,6 +118,7 @@ def load_dataset(self):
         self.target_scaler = self.train_dataset.target_scaler
         self.input_dim = self.train_dataset.input_dim
         self.val_split = self.config["dataset"].get("val_split", 0)
+        self.config["dataset"]["descriptor"] = descriptor_setup
         if not self.debug:
             normalizers = {
                 "target": self.target_scaler,
@@ -125,7 +126,6 @@
             }
             torch.save(normalizers, os.path.join(self.cp_dir, "normalizers.pt"))
             # clean/organize config
-            self.config["dataset"]["descriptor"] = descriptor_setup
             self.config["dataset"]["fp_length"] = self.input_dim
             torch.save(self.config, os.path.join(self.cp_dir, "config.pt"))
         print("Loading dataset: {} images".format(len(self.train_dataset)))
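
The trainer hunk moves the `config["dataset"]["descriptor"] = descriptor_setup` assignment out of the `if not self.debug:` block, so the descriptor setup is recorded on the config even when `debug` is enabled and checkpoint files are skipped; the test hunk flips the consistency test to `"debug": True` to exercise that path. Below is a minimal, self-contained sketch of that behavior, not amptorch's actual `AtomsTrainer` code: the function name `load_dataset_sketch`, the placeholder `fp_length`, and the example `descriptor_setup` tuple are illustrative assumptions.

```python
# Sketch of the patched control flow, assuming a config dict shaped like the
# one in consistency_test.py. This is not the amptorch implementation.
import os
import tempfile


def load_dataset_sketch(config, descriptor_setup, cp_dir):
    """Record the descriptor unconditionally; write checkpoint artifacts
    only when the run is not in debug mode (mirrors the patched logic)."""
    debug = config["cmd"]["debug"]

    # After the fix: this happens for both debug and non-debug runs.
    config["dataset"]["descriptor"] = descriptor_setup

    if not debug:
        # Checkpoint files are only written for real (non-debug) runs.
        config["dataset"]["fp_length"] = 40  # hypothetical fingerprint length
        with open(os.path.join(cp_dir, "config.pt"), "w") as f:
            f.write(repr(config))
    return config


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as cp_dir:
        cfg = {"cmd": {"debug": True}, "dataset": {}}
        out = load_dataset_sketch(
            cfg, descriptor_setup=("gaussian", {"nsigmas": 5}), cp_dir=cp_dir
        )
        # With debug=True the descriptor is still available downstream,
        # while no checkpoint files were written.
        assert out["dataset"]["descriptor"] is not None
        assert os.listdir(cp_dir) == []
        print("debug run keeps descriptor in config:", out["dataset"]["descriptor"])
```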