Skip to content

Commit

Permalink
Reverted merge mistakes with database.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bnebgen-LANL committed Oct 13, 2023
1 parent 019c314 commit 5bbdb6a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions hippynn/databases/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample
if not self.splitting_completed:
raise ValueError("Database has not yet been split.")

if split_type in ("train", "valid", "test"):
data = [self.splits[split_type][k] for k in self.var_list]
else:
raise ValueError("Datatype {} Invalid. Must be one of 'train','valid','test'".format(split_type))
if split_type not in self.splits:
raise ValueError(f"Split {split_type} Invalid. Current splits:{list(self.splits.keys())}")

data = [self.splits[split_type][k] for k in self.var_list]

if evaluation_mode == "train":
if split_type != "train":
Expand All @@ -205,7 +205,7 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample
elif evaluation_mode == "eval":
shuffle = False
else:
raise ValueError("Evaluation_mode ({}) must be one of 'train' or 'eval'")
raise ValueError(f"Evaluation_mode ({evaluation_mode}) must be one of 'train' or 'eval'")

dataset = NamedTensorDataset(self.var_list, *data)
if subsample:
Expand Down

0 comments on commit 5bbdb6a

Please sign in to comment.