From 86eaaafd9eaa834e9f8952f985a685461529677d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 19 Apr 2021 15:16:30 +0100 Subject: [PATCH] update --- tests/data/test_data_pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 6e065bae31..9d15247cf2 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -598,7 +598,7 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: batch = batch[0] assert torch.equal(batch["a"], tensor([0, 1])) assert torch.equal(batch["b"], tensor([1, 2])) - return False + return [False] @staticmethod def fn_test_load_data() -> List[torch.Tensor]: @@ -649,6 +649,8 @@ def training_step(self, batch, batch_idx): assert batch is None def validation_step(self, batch, batch_idx): + if isinstance(batch, list): + batch = batch[0] assert batch is False def test_step(self, batch, batch_idx):