From 8de809e346e63ff2d5b653ff7d94f0de17da3afd Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 8 Nov 2021 12:36:10 +0000 Subject: [PATCH] Clean `Output` (#939) --- CHANGELOG.md | 6 +++ docs/source/api/data.rst | 1 - flash/audio/speech_recognition/model.py | 2 +- flash/core/data/io/output.py | 38 +--------------- flash/core/model.py | 11 ++--- flash/core/utilities/types.py | 2 +- flash/image/classification/model.py | 3 +- flash/image/detection/model.py | 3 +- flash/image/segmentation/model.py | 2 +- flash/pointcloud/detection/model.py | 2 +- flash/pointcloud/segmentation/model.py | 2 +- flash/tabular/classification/model.py | 2 +- flash/tabular/regression/model.py | 2 +- flash/template/classification/model.py | 2 +- flash/text/classification/model.py | 2 +- flash/video/classification/model.py | 3 +- tests/core/data/io/test_output.py | 58 ++----------------------- 17 files changed, 28 insertions(+), 113 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 457eddcc2a..6852f83f3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +### Removed + +- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939)) + +- Removed `Output.enable` and `Output.disable` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939)) + ## [0.5.2] - 2021-11-05 diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 00e35b8529..4e46fd1434 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -16,7 +16,6 @@ _________________________ :template: classtemplate.rst ~flash.core.data.io.output.Output - ~flash.core.data.io.output.OutputMapping flash.core.data.auto_dataset ____________________________ diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index 9d895279d4..4370aaf13f 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -41,7 +41,7 @@ class SpeechRecognition(Task): learning_rate: Learning rate to use for training, defaults to ``1e-5``. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. - output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. """ backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES diff --git a/flash/core/data/io/output.py b/flash/core/data/io/output.py index 18d50b73a1..ce2cd9ef4b 100644 --- a/flash/core/data/io/output.py +++ b/flash/core/data/io/output.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Mapping +from typing import Any import torch -import flash from flash.core.data.properties import Properties from flash.core.data.utils import convert_to_modules @@ -24,18 +23,6 @@ class Output(Properties): """An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which is used to convert the model output into the desired output format when predicting.""" - def __init__(self): - super().__init__() - self._is_enabled = True - - def enable(self): - """Enable output transformation.""" - self._is_enabled = True - - def disable(self): - """Disable output transformation.""" - self._is_enabled = False - @staticmethod def transform(sample: Any) -> Any: """Convert the given sample into the desired output format. @@ -49,28 +36,7 @@ def transform(sample: Any) -> Any: return sample def __call__(self, sample: Any) -> Any: - if self._is_enabled: - return self.transform(sample) - return sample - - -class OutputMapping(Output): - """If the model output is a dictionary, then the :class:`.OutputMapping` enables each entry in the dictionary - to be passed to it's own :class:`.Output`.""" - - def __init__(self, outputs: Mapping[str, Output]): - super().__init__() - - self._outputs = outputs - - def transform(self, sample: Any) -> Any: - if isinstance(sample, Mapping): - return {key: output.transform(sample[key]) for key, output in self._outputs.items()} - raise ValueError("The model output must be a mapping when using an OutputMapping.") - - def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): - for output in self._outputs.values(): - output.attach_data_pipeline_state(data_pipeline_state) + return self.transform(sample) class _OutputProcessor(torch.nn.Module): diff --git a/flash/core/model.py b/flash/core/model.py index b4c5aca3c7..405f156fc8 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -39,7 +39,7 @@ from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource -from flash.core.data.io.output import Output, OutputMapping +from flash.core.data.io.output import Output from flash.core.data.process import Deserializer, DeserializerMapping, Postprocess, Preprocess from flash.core.data.properties import ProcessState from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY @@ -319,8 +319,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check deserialize the input preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task. postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task. - output: Either a single :class:`~flash.core.data.io.output.Output` or a mapping of these to - serialize the output e.g. convert the model output into the desired output format when predicting. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
""" optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY @@ -630,9 +629,7 @@ def output(self) -> Optional[Output]: @torch.jit.unused @output.setter - def output(self, output: Union[Output, Mapping[str, Output]]): - if isinstance(output, Mapping): - output = OutputMapping(output) + def output(self, output: Output): self._output = output @torch.jit.unused @@ -662,7 +659,7 @@ def serializer(self) -> Optional[Output]: "It will be removed in v%(remove_in)s.", stream=functools.partial(warn, category=FutureWarning), ) - def serializer(self, serializer: Union[Output, Mapping[str, Output]]): + def serializer(self, serializer: Output): self.output = serializer def build_data_pipeline( diff --git a/flash/core/utilities/types.py b/flash/core/utilities/types.py index e7597d963a..ec968792d8 100644 --- a/flash/core/utilities/types.py +++ b/flash/core/utilities/types.py @@ -16,4 +16,4 @@ DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]] PREPROCESS_TYPE = Optional[Preprocess] POSTPROCESS_TYPE = Optional[Postprocess] -OUTPUT_TYPE = Optional[Union[Output, Mapping[str, Output]]] +OUTPUT_TYPE = Optional[Output] diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index ff2599df95..90e0181523 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -60,8 +60,7 @@ def fn_resnet(pretrained: bool = True): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. multi_label: Whether the targets are multi-label or not. - output: A instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such - to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. training_strategy: string indicating the training strategy. Adjust if you want to use `learn2learn` for doing meta-learning research training_strategy_kwargs: Additional kwargs for setting the training strategy diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index b0075a6956..655de7e965 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -42,8 +42,7 @@ class ObjectDetector(AdapterTask): pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. learning_rate: The learning rate to use for training - output: A instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such - to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. kwargs: additional kwargs nessesary for initializing the backbone task """ diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index a9589b20ad..17f96a0fed 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -68,7 +68,7 @@ class SemanticSegmentation(ClassificationTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.IOU`. learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. - output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
postprocess: :class:`~flash.core.data.process.Postprocess` use for post processing samples. """ diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index efe402909e..eafa0a3fd8 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -53,7 +53,7 @@ class PointCloudObjectDetector(Task): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. lambda_loss_cls: The value to scale the loss classification. lambda_loss_bbox: The value to scale the bounding boxes loss. lambda_loss_dir: The value to scale the bounding boxes direction loss. diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index 227bf63e59..4c4d33f6db 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -84,7 +84,7 @@ class PointCloudSegmentation(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. """ backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index cec72473bd..6c1ad46569 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -42,7 +42,7 @@ class TabularClassifier(ClassificationTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. - output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. **tabnet_kwargs: Optional additional arguments for the TabNet model, see `pytorch_tabnet `_. """ diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index d5090a1298..f0837ad14e 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -40,7 +40,7 @@ class TabularRegressor(RegressionTask): `metric(preds,target)` and return a single scalar tensor. learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. - serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. **tabnet_kwargs: Optional additional arguments for the TabNet model, see `pytorch_tabnet `_. 
""" diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index 804549d83e..66e2ee2253 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -40,7 +40,7 @@ class TemplateSKLearnClassifier(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. """ backbones: FlashRegistry = TEMPLATE_BACKBONES diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index da491f2026..950b5dc902 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -46,7 +46,7 @@ class TextClassifier(ClassificationTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to `1e-3` multi_label: Whether the targets are multi-label or not. - output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 6ae0340355..20f2890fbd 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -93,8 +93,7 @@ class VideoClassifier(ClassificationTask): head: either a `nn.Module` or a callable function that converts the features extrated from the backbone into class log probabilities (assuming default loss function). If `None`, will default to using a single linear layer. - output: A instance of :class:`~flash.core.data.io.output.Output` that determines how the output - should be serialized e.g. convert the model output into the desired output format when predicting. + output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
""" backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py index 1d18ad1315..e875c2f86b 100644 --- a/tests/core/data/io/test_output.py +++ b/tests/core/data/io/test_output.py @@ -14,33 +14,25 @@ import os from unittest.mock import Mock -import pytest import torch from torch.utils.data import DataLoader from flash.core.classification import Labels from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import LabelsState -from flash.core.data.io.output import Output, OutputMapping +from flash.core.data.io.output import Output from flash.core.data.process import DefaultPreprocess -from flash.core.data.properties import ProcessState from flash.core.model import Task from flash.core.trainer import Trainer -def test_output_enable_disable(): - """Tests that ``Output`` can be enabled and disabled correctly.""" - +def test_output(): + """Tests basic ``Output`` methods.""" my_output = Output() assert my_output.transform("test") == "test" - my_output.transform = Mock() - - my_output.disable() - assert my_output("test") == "test" - my_output.transform.assert_not_called() - my_output.enable() + my_output.transform = Mock() my_output("test") my_output.transform.assert_called_once() @@ -65,45 +57,3 @@ def __init__(self): model = CustomModel.load_from_checkpoint(checkpoint_file) assert isinstance(model._data_pipeline_state, DataPipelineState) assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) - - -def test_output_mapping(): - """Tests that ``OutputMapping`` correctly passes its inputs to the underlying outputs. - - Also checks that state is retrieved / loaded correctly. - """ - - output1 = Output() - output1.transform = Mock(return_value="test1") - - class output1State(ProcessState): - pass - - output2 = Output() - output2.transform = Mock(return_value="test2") - - class output2State(ProcessState): - pass - - output_mapping = OutputMapping({"key1": output1, "key2": output2}) - assert output_mapping({"key1": "output1", "key2": "output2"}) == {"key1": "test1", "key2": "test2"} - output1.transform.assert_called_once_with("output1") - output2.transform.assert_called_once_with("output2") - - with pytest.raises(ValueError, match="output must be a mapping"): - output_mapping("not a mapping") - - output1_state = output1State() - output2_state = output2State() - - output1.set_state(output1_state) - output2.set_state(output2_state) - - data_pipeline_state = DataPipelineState() - output_mapping.attach_data_pipeline_state(data_pipeline_state) - - assert output1._data_pipeline_state is data_pipeline_state - assert output2._data_pipeline_state is data_pipeline_state - - assert data_pipeline_state.get_state(output1State) is output1_state - assert data_pipeline_state.get_state(output2State) is output2_state