diff --git a/tcn_hpl/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py index 68e514b8d..3e05fea50 100644 --- a/tcn_hpl/data/ptg_datamodule.py +++ b/tcn_hpl/data/ptg_datamodule.py @@ -177,7 +177,7 @@ def setup(self, stage: Optional[str] = None) -> None: kwcoco.CocoDataset(self.hparams.coco_train_poses), self.hparams.target_framerate, ) - if stage == "validate" and not self.data_val: + if (stage == "validate" or stage == "fit") and not self.data_val: self.data_val.load_data_offline( kwcoco.CocoDataset(self.hparams.coco_validation_activities), kwcoco.CocoDataset(self.hparams.coco_validation_objects),