From 5bbdb6a5a4d65e665a1323d69423c737493fdefc Mon Sep 17 00:00:00 2001 From: Ben Nebgen Date: Fri, 13 Oct 2023 12:43:02 -0600 Subject: [PATCH] Reverted merge mistakes with database.py --- hippynn/databases/database.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index 3466d008..e9375b49 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -191,10 +191,10 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample if not self.splitting_completed: raise ValueError("Database has not yet been split.") - if split_type in ("train", "valid", "test"): - data = [self.splits[split_type][k] for k in self.var_list] - else: - raise ValueError("Datatype {} Invalid. Must be one of 'train','valid','test'".format(split_type)) + if split_type not in self.splits: + raise ValueError(f"Split {split_type} Invalid. Current splits:{list(self.splits.keys())}") + + data = [self.splits[split_type][k] for k in self.var_list] if evaluation_mode == "train": if split_type != "train": @@ -205,7 +205,7 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample elif evaluation_mode == "eval": shuffle = False else: - raise ValueError("Evaluation_mode ({}) must be one of 'train' or 'eval'") + raise ValueError(f"Evaluation_mode ({evaluation_mode}) must be one of 'train' or 'eval'") dataset = NamedTensorDataset(self.var_list, *data) if subsample: