diff --git a/CHANGELOG.md b/CHANGELOG.md
index 804b60dd19..cc038e9a92 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -44,6 +44,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed JIT tracing tests where no `Trainer` was attached to the model ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))
+
+- Fixed examples for BaaL integration by removing usage of `on_<stage>_dataloader` hooks (removed in PL 1.7.0) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))
+
+- Fixed examples for BaaL integration for the case when the `probabilities` list is empty ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))
+
+- Fixed a bug where collate functions were not attached successfully after the `DataLoader` was initialized (since PL 1.7.0, changing attributes after initialization has no effect) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))
+
 - Fixed a bug where grayscale images were not properly converted to RGB when loaded. ([#1394](https://github.com/PyTorchLightning/lightning-flash/pull/1394))
 
 - Fixed a bug where size of mask for instance segmentation doesn't match to size of original image. ([#1353](https://github.com/PyTorchLightning/lightning-flash/pull/1353))
diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py
index fc2a43ad03..bf06ae1b38 100644
--- a/flash/core/integrations/icevision/adapter.py
+++ b/flash/core/integrations/icevision/adapter.py
@@ -110,6 +110,16 @@ def _wrap_collate_fn(collate_fn, samples):
             DataKeys.METADATA: metadata,
         }
 
+    def _update_collate_fn_dataloader(self, new_collate_fn, data_loader):
+        # Starting with PL 1.7.0, changing attributes after the DataLoader is initialized will not work,
+        # so we manually update the collate_fn for the dataloader, for now.
+        new_kwargs = getattr(data_loader, "__pl_saved_kwargs", None)
+        if new_kwargs:
+            new_kwargs["collate_fn"] = new_collate_fn
+            setattr(data_loader, "__pl_saved_kwargs", new_kwargs)
+        data_loader.collate_fn = new_collate_fn
+        return data_loader
+
     def process_train_dataset(
         self,
         dataset: InputBase,
@@ -134,12 +144,16 @@ def process_train_dataset(
             persistent_workers=persistent_workers,
         )
 
-        data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
+        data_loader = self._update_collate_fn_dataloader(
+            functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
+        )
 
         input_transform = input_transform or self.input_transform
         if input_transform is not None:
             input_transform.inject_collate_fn(data_loader.collate_fn)
-            data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.TRAINING, input_transform)
+            data_loader = self._update_collate_fn_dataloader(
+                create_worker_input_transform_processor(RunningStage.TRAINING, input_transform), data_loader
+            )
         return data_loader
 
     def process_val_dataset(
@@ -166,12 +180,16 @@ def process_val_dataset(
             persistent_workers=persistent_workers,
         )
 
-        data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
+        data_loader = self._update_collate_fn_dataloader(
+            functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
+        )
 
         input_transform = input_transform or self.input_transform
         if input_transform is not None:
             input_transform.inject_collate_fn(data_loader.collate_fn)
-            data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.VALIDATING, input_transform)
+            data_loader = self._update_collate_fn_dataloader(
+                create_worker_input_transform_processor(RunningStage.VALIDATING, input_transform), data_loader
+            )
         return data_loader
 
     def process_test_dataset(
@@ -198,12 +216,16 @@ def process_test_dataset(
             persistent_workers=persistent_workers,
         )
 
-        data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
+        data_loader = self._update_collate_fn_dataloader(
+            functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
+        )
 
         input_transform = input_transform or self.input_transform
         if input_transform is not None:
             input_transform.inject_collate_fn(data_loader.collate_fn)
-            data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.TESTING, input_transform)
+            data_loader = self._update_collate_fn_dataloader(
+                create_worker_input_transform_processor(RunningStage.TESTING, input_transform), data_loader
+            )
         return data_loader
 
     def process_predict_dataset(
@@ -230,12 +252,16 @@ def process_predict_dataset(
             persistent_workers=persistent_workers,
         )
 
-        data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
+        data_loader = self._update_collate_fn_dataloader(
+            functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
+        )
 
         input_transform = input_transform or self.input_transform
         if input_transform is not None:
             input_transform.inject_collate_fn(data_loader.collate_fn)
-            data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.PREDICTING, input_transform)
+            data_loader = self._update_collate_fn_dataloader(
+                create_worker_input_transform_processor(RunningStage.PREDICTING, input_transform), data_loader
+            )
         return data_loader
 
     def training_step(self, batch, batch_idx) -> Any:
diff --git a/flash/core/utilities/lightning_cli.py b/flash/core/utilities/lightning_cli.py
index 99a5e11153..ecfaf52589 100644
--- a/flash/core/utilities/lightning_cli.py
+++ b/flash/core/utilities/lightning_cli.py
@@ -97,7 +97,8 @@ def add_lightning_class_args(
             lightning_class = class_from_function(lightning_class)
 
         if inspect.isclass(lightning_class) and issubclass(
-            cast(type, lightning_class), (Trainer, LightningModule, LightningDataModule, Callback)
+            cast(type, lightning_class),
+            (Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
         ):
             if issubclass(cast(type, lightning_class), Callback):
                 self.callback_keys.append(nested_key)
diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py
index f427b13f60..f1ca83799d 100644
--- a/flash/image/classification/integrations/baal/data.py
+++ b/flash/image/classification/integrations/baal/data.py
@@ -180,7 +180,7 @@ def label(self, probabilities: List[Tensor] = None, indices=None):
             raise MisconfigurationException(
                 "The `probabilities` and `indices` are mutually exclusive, pass only of one them."
             )
-        if probabilities is not None:
+        if probabilities is not None and len(probabilities) != 0:
             probabilities = torch.cat([p[0].unsqueeze(0) for p in probabilities], dim=0)
             uncertainties = self.heuristic.get_uncertainties(probabilities)
             indices = np.argsort(uncertainties)
diff --git a/flash/image/classification/integrations/baal/loop.py b/flash/image/classification/integrations/baal/loop.py
index 2a1142a011..15c0615f21 100644
--- a/flash/image/classification/integrations/baal/loop.py
+++ b/flash/image/classification/integrations/baal/loop.py
@@ -91,7 +91,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
         assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
         if self._datamodule_state_dict is not None:
             self.trainer.datamodule.load_state_dict(self._datamodule_state_dict)
-        self.trainer.predict_loop._return_predictions = True
+        self.trainer.predict_loop.return_predictions = True
         self._lightning_module = self.trainer.lightning_module
         self._model_state_dict = deepcopy(self._lightning_module.state_dict())
         self.inference_model = InferenceMCDropoutTask(self._lightning_module, self.inference_iteration)
@@ -165,21 +165,18 @@ def _connect(self, model: LightningModule):
     def _reset_fitting(self):
         self.trainer.state.fn = TrainerFn.FITTING
         self.trainer.training = True
-        self.trainer.lightning_module.on_train_dataloader()
         self._connect(self._lightning_module)
         self.fit_loop.epoch_progress = Progress()
 
     def _reset_predicting(self):
         self.trainer.state.fn = TrainerFn.PREDICTING
         self.trainer.predicting = True
-        self.trainer.lightning_module.on_predict_dataloader()
         self._connect(self.inference_model)
 
     def _reset_testing(self):
         self.trainer.state.fn = TrainerFn.TESTING
         self.trainer.state.status = TrainerStatus.RUNNING
         self.trainer.testing = True
-        self.trainer.lightning_module.on_test_dataloader()
         self._connect(self._lightning_module)
 
     def _reset_dataloader_for_stage(self, running_state: RunningStage):
diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py
index 5b19a69d95..cc2ee0203b 100644
--- a/flash/image/embedding/model.py
+++ b/flash/image/embedding/model.py
@@ -154,8 +154,8 @@ def on_train_start(self) -> None:
     def on_train_epoch_end(self) -> None:
         self.adapter.on_train_epoch_end()
 
-    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
-        self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
+        self.adapter.on_train_batch_end(outputs, batch, batch_idx, *args)
 
     @classmethod
     @requires("image", "vissl", "fairscale")
diff --git a/flash/image/embedding/vissl/hooks.py b/flash/image/embedding/vissl/hooks.py
index 9176883f59..e8ecd801a9 100644
--- a/flash/image/embedding/vissl/hooks.py
+++ b/flash/image/embedding/vissl/hooks.py
@@ -96,7 +96,7 @@ def on_train_start(self) -> None:
         for hook in self.hooks:
             hook.on_start(self.task)
 
-    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
         self.task.iteration += 1
 
     def on_train_epoch_end(self) -> None:
diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt
index ed1c2f386f..34fed43bee 100644
--- a/requirements/datatype_image_extras.txt
+++ b/requirements/datatype_image_extras.txt
@@ -13,3 +13,5 @@ fairscale
 # pin PL for testing, remove when fastface is updated
 pytorch-lightning<1.5.0
 torchmetrics<0.8.0 # pinned PL so we force a compatible TM version
+# effdet had an issue with PyTorch 1.12, and icevision doesn't support effdet's latest version yet (0.3.0)
+torch<1.12
diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py
index 8ee6797b17..9e0026007e 100644
--- a/tests/helpers/task_tester.py
+++ b/tests/helpers/task_tester.py
@@ -102,8 +102,10 @@ def _test_jit_trace(self, tmpdir):
     path = os.path.join(tmpdir, "test.pt")
 
     model = self.instantiated_task
+    trainer = self.instantiated_trainer
     model.eval()
 
+    model.trainer = trainer
     model = torch.jit.trace(model, self.example_forward_input)
 
     torch.jit.save(model, path)
@@ -117,8 +119,10 @@ def _test_jit_script(self, tmpdir):
     path = os.path.join(tmpdir, "test.pt")
 
     model = self.instantiated_task
+    trainer = self.instantiated_trainer
     model.eval()
 
+    model.trainer = trainer
     model = torch.jit.script(model)
 
     torch.jit.save(model, path)
@@ -261,10 +265,17 @@ class TaskTester(metaclass=TaskTesterMeta):
         "test_cli": [pytest.mark.parametrize("extra_args", [[]])],
     }
 
+    trainer_args: Tuple = ()
+    trainer_kwargs: Dict = {}
+
     @property
    def instantiated_task(self):
        return self.task(*self.task_args, **self.task_kwargs)
 
+    @property
+    def instantiated_trainer(self):
+        return flash.Trainer(*self.trainer_args, **self.trainer_kwargs)
+
     @property
     def example_forward_input(self):
         pass
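
For reference, below is a minimal standalone sketch of the collate_fn workaround that `_update_collate_fn_dataloader` implements in the patch above. The helper name `update_collate_fn` and the toy dataset are illustrative only; the sketch assumes a `DataLoader` that may carry the `"__pl_saved_kwargs"` attribute which newer PL versions use when re-instantiating dataloaders.

# Sketch (not part of the patch): update a DataLoader's collate_fn after construction,
# keeping PL's saved constructor kwargs in sync so a re-created loader picks it up too.
from torch.utils.data import DataLoader, Dataset


class _ToyDataset(Dataset):  # illustrative stand-in for a real dataset
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx


def update_collate_fn(data_loader: DataLoader, new_collate_fn) -> DataLoader:
    # If PL saved the constructor kwargs on the loader, patch them so any
    # re-instantiated DataLoader also receives the new collate_fn.
    saved_kwargs = getattr(data_loader, "__pl_saved_kwargs", None)
    if saved_kwargs:
        saved_kwargs["collate_fn"] = new_collate_fn
        setattr(data_loader, "__pl_saved_kwargs", saved_kwargs)
    # Also update the live instance so the current run uses the new collate_fn.
    data_loader.collate_fn = new_collate_fn
    return data_loader


if __name__ == "__main__":
    loader = DataLoader(_ToyDataset(), batch_size=2)
    loader = update_collate_fn(loader, lambda batch: list(reversed(batch)))
    print(next(iter(loader)))  # batches now pass through the new collate_fn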