From d1be93cd2d5b59af8bc40db2e6a606688b9d071c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 5 Nov 2021 17:33:36 +0000 Subject: [PATCH] Rename Serializer to Output and move to `flash.core.data.io.output` (#927) Co-authored-by: Ananya Harsh Jha --- CHANGELOG.md | 9 ++ docs/source/api/core.rst | 4 +- docs/source/api/data.rst | 13 +- docs/source/api/flash.rst | 2 +- docs/source/api/image.rst | 6 +- docs/source/common/finetuning_example.rst | 4 +- docs/source/general/data.rst | 14 +-- docs/source/general/predictions.rst | 8 +- docs/source/integrations/fiftyone.rst | 4 +- docs/source/template/optional.rst | 10 +- flash/__init__.py | 4 +- flash/audio/speech_recognition/model.py | 8 +- flash/core/classification.py | 58 ++++----- flash/core/data/batch.py | 24 +--- flash/core/data/data_pipeline.py | 23 ++-- flash/core/data/io/__init__.py | 0 flash/core/data/io/output.py | 85 +++++++++++++ .../core/data/{serialization.py => output.py} | 8 +- flash/core/data/process.py | 97 +++++++-------- flash/core/integrations/fiftyone/utils.py | 2 +- flash/core/model.py | 116 +++++++++++------- flash/core/regression.py | 6 +- flash/core/serve/flash_components.py | 8 +- flash/core/serve/types/image.py | 2 +- flash/core/utilities/types.py | 5 +- flash/image/classification/model.py | 8 +- flash/image/detection/model.py | 10 +- .../detection/{serialization.py => output.py} | 10 +- flash/image/face_detection/model.py | 14 +-- flash/image/instance_segmentation/model.py | 8 +- flash/image/keypoint_detection/model.py | 8 +- flash/image/segmentation/data.py | 2 +- flash/image/segmentation/model.py | 10 +- .../{serialization.py => output.py} | 16 +-- flash/image/style_transfer/model.py | 8 +- flash/pointcloud/detection/model.py | 12 +- flash/pointcloud/segmentation/model.py | 12 +- flash/tabular/classification/model.py | 8 +- flash/tabular/regression/model.py | 6 +- flash/template/classification/model.py | 8 +- flash/text/classification/model.py | 8 +- flash/video/classification/model.py | 8 +- .../image_classification_active_learning.py | 2 +- .../fiftyone/image_classification.py | 4 +- .../image_classification_fiftyone_datasets.py | 4 +- .../integrations/fiftyone/object_detection.py | 6 +- .../labelstudio/image_classification.py | 2 +- .../semantic_segmentation/inference_server.py | 4 +- .../inference_server.py | 2 +- tests/core/data/io/__init__.py | 0 tests/core/data/io/test_output.py | 109 ++++++++++++++++ tests/core/data/test_batch.py | 2 +- tests/core/data/test_data_pipeline.py | 7 +- tests/core/data/test_process.py | 89 +------------- tests/core/test_classification.py | 34 ++--- .../classification/test_active_learning.py | 4 +- tests/image/classification/test_model.py | 2 +- .../{test_serialization.py => test_output.py} | 10 +- .../{test_serialization.py => test_output.py} | 12 +- 59 files changed, 549 insertions(+), 420 deletions(-) create mode 100644 flash/core/data/io/__init__.py create mode 100644 flash/core/data/io/output.py rename flash/core/data/{serialization.py => output.py} (76%) rename flash/image/detection/{serialization.py => output.py} (92%) rename flash/image/segmentation/{serialization.py => output.py} (89%) create mode 100644 tests/core/data/io/__init__.py create mode 100644 tests/core/data/io/test_output.py rename tests/image/detection/{test_serialization.py => test_output.py} (89%) rename tests/image/segmentation/{test_serialization.py => test_output.py} (89%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bd9a8dbac..457eddcc2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Changed classes named `*Serializer` and properties / variables named `serializer` to be `*Output` and `output` respectively ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) + +### Deprecated + +- Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) + +- Deprecated `Task.serializer` in favour of `Task.output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) + ### Fixed + ## [0.5.2] - 2021-11-05 ### Added diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst index 8e7b011d2d..011c95bed6 100644 --- a/docs/source/api/core.rst +++ b/docs/source/api/core.rst @@ -27,12 +27,12 @@ _________________________ :template: classtemplate.rst ~flash.core.classification.Classes - ~flash.core.classification.ClassificationSerializer + ~flash.core.classification.ClassificationOutput ~flash.core.classification.ClassificationTask ~flash.core.classification.FiftyOneLabels ~flash.core.classification.Labels ~flash.core.classification.Logits - ~flash.core.classification.PredsClassificationSerializer + ~flash.core.classification.PredsClassificationOutput ~flash.core.classification.Probabilities flash.core.finetuning diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 497fd916e9..00e35b8529 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -7,6 +7,17 @@ flash.core.data :local: :backlinks: top +flash.core.data.io.output +_________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.io.output.Output + ~flash.core.data.io.output.OutputMapping + flash.core.data.auto_dataset ____________________________ @@ -114,8 +125,6 @@ _______________________ ~flash.core.data.process.Deserializer ~flash.core.data.process.Postprocess ~flash.core.data.process.Preprocess - ~flash.core.data.process.SerializerMapping - ~flash.core.data.process.Serializer flash.core.data.properties __________________________ diff --git a/docs/source/api/flash.rst b/docs/source/api/flash.rst index 06540aad69..bd087b64d7 100644 --- a/docs/source/api/flash.rst +++ b/docs/source/api/flash.rst @@ -12,6 +12,6 @@ flash ~flash.core.data.callback.FlashCallback ~flash.core.data.process.Preprocess ~flash.core.data.process.Postprocess - ~flash.core.data.process.Serializer + ~flash.core.data.io.output.Output ~flash.core.model.Task ~flash.core.trainer.Trainer diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst index ded8bccd33..84e351a74c 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -45,7 +45,7 @@ ________________ detection.data.FiftyOneParser detection.data.ObjectDetectionFiftyOneDataSource detection.data.ObjectDetectionPreprocess - detection.serialization.FiftyOneDetectionLabels + detection.output.FiftyOneDetectionLabels Keypoint Detection __________________ @@ -102,8 +102,8 @@ ____________ segmentation.data.SemanticSegmentationFiftyOneDataSource segmentation.data.SemanticSegmentationDeserializer segmentation.model.SemanticSegmentationPostprocess - segmentation.serialization.FiftyOneSegmentationLabels - segmentation.serialization.SegmentationLabels + segmentation.output.FiftyOneSegmentationLabels + segmentation.output.SegmentationLabels .. autosummary:: :toctree: generated/ diff --git a/docs/source/common/finetuning_example.rst b/docs/source/common/finetuning_example.rst index b45b0cfd97..63022a9d7d 100644 --- a/docs/source/common/finetuning_example.rst +++ b/docs/source/common/finetuning_example.rst @@ -55,8 +55,8 @@ Once you've finetuned, use the model to predict: .. testcode:: finetune - # Serialize predictions as labels, automatically inferred from the training data in part 2. - model.serializer = Labels() + # Output predictions as labels, automatically inferred from the training data in part 2. + model.output = Labels() predictions = model.predict( [ diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index 8e815c5a83..c5d51d1f96 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -26,7 +26,7 @@ Here are common terms you need to be familiar with: * - :class:`~flash.core.data.data_module.DataModule` - The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders. * - :class:`~flash.core.data.data_pipeline.DataPipeline` - - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects. + - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.io.output.Output` objects. * - :class:`~flash.core.data.data_source.DataSource` - The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names). * - :class:`~flash.core.data.process.Preprocess` @@ -37,8 +37,8 @@ Here are common terms you need to be familiar with: * - :class:`~flash.core.data.process.Postprocess` - The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic. The :class:`~flash.core.data.process.Postprocess` hooks cover from model outputs to predictions export. - * - :class:`~flash.core.data.process.Serializer` - - The :class:`~flash.core.data.process.Serializer` provides a single :meth:`~flash.core.data.process.Serializer.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction. + * - :class:`~flash.core.data.io.output.Output` + - The :class:`~flash.core.data.io.output.Output` provides a single :meth:`~flash.core.data.io.output.Output.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction. ******************************************* @@ -59,7 +59,7 @@ Usually, extra processing logic should be added to bridge the gap between traini The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms. -The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). +The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), Flash gives the user much more granular control over their data processing flow. @@ -383,18 +383,18 @@ Example:: predictions = lightning_module(data) -Postprocess and Serializer +Postprocess and Output __________________________ Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash :class:`~flash.core.data.data_pipeline.DataPipeline` will execute the :class:`~flash.core.data.process.Postprocess` hooks and the -:class:`~flash.core.data.process.Serializer` behind the scenes. +:class:`~flash.core.data.io.output.Output` behind the scenes. First, the :meth:`~flash.core.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions. Then, the :meth:`~flash.core.data.process.Postprocess.uncollate` will split the batch into individual predictions. Next, the :meth:`~flash.core.data.process.Postprocess.per_sample_transform` will be applied on each prediction. -Finally, the :meth:`~flash.core.data.process.Serializer.serialize` method will be called to serialize the predictions. +Finally, the :meth:`~flash.core.data.io.output.Output.serialize` method will be called to serialize the predictions. .. note:: The transform can be applied either on device or ``CPU``. diff --git a/docs/source/general/predictions.rst b/docs/source/general/predictions.rst index 88d7e5cd9d..3181d9766e 100644 --- a/docs/source/general/predictions.rst +++ b/docs/source/general/predictions.rst @@ -57,8 +57,8 @@ Predict on a csv file Serializing predictions ======================= -To change how predictions are serialized you can attach a :class:`~flash.core.data.process.Serializer` to your -:class:`~flash.core.model.Task`. For example, you can choose to serialize outputs as probabilities (for more options see the API +To change the output format of predictions you can attach an :class:`~flash.core.data.io.output.Output` to your +:class:`~flash.core.model.Task`. For example, you can choose to output probabilities (for more options see the API reference below). @@ -77,8 +77,8 @@ reference below). "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" ) - # 3. Attach the Serializer - model.serializer = Probabilities() + # 3. Attach the Output + model.output = Probabilities() # 4. Predict whether the image contains an ant or a bee predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") diff --git a/docs/source/integrations/fiftyone.rst b/docs/source/integrations/fiftyone.rst index 8592fad47b..a44671ae99 100644 --- a/docs/source/integrations/fiftyone.rst +++ b/docs/source/integrations/fiftyone.rst @@ -47,8 +47,8 @@ semantic segmentation tasks. Doing so is as easy as updating your model to use one of the following serializers: * :class:`FiftyOneLabels(return_filepath=True)` -* :class:`FiftyOneSegmentationLabels(return_filepath=True)` -* :class:`FiftyOneDetectionLabels(return_filepath=True)` +* :class:`FiftyOneSegmentationLabels(return_filepath=True)` +* :class:`FiftyOneDetectionLabels(return_filepath=True)` The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize your predictions in the diff --git a/docs/source/template/optional.rst b/docs/source/template/optional.rst index 9c27bb092c..1b34ac3f8f 100644 --- a/docs/source/template/optional.rst +++ b/docs/source/template/optional.rst @@ -20,21 +20,21 @@ Here's how we create our transforms in the :class:`~flash.image.classification.d :language: python :pyobject: ImageClassificationPreprocess.default_transforms -Add output serializers to your Task -====================================== +Add outputs to your Task +======================== We recommend that you do most of the heavy lifting in the :class:`~flash.core.data.process.Postprocess`. Specifically, it should include any formatting and transforms that should always be applied to the predictions. -If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.process.Serializer` implementations in a ``serialization.py`` file. +If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.io.output.Output` implementations in an ``output.py`` file. Some good examples are in `flash/core/classification.py `_. -Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.process.Serializer`: +Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.io.output.Output`: .. literalinclude:: ../../../flash/core/classification.py :language: python :pyobject: Classes -Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.process.Serializer`: +Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.io.output.Output`: .. literalinclude:: ../../../flash/core/classification.py :language: python diff --git a/flash/__init__.py b/flash/__init__.py index 4b39185dad..eb04b45bb9 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -24,6 +24,7 @@ from flash.core.data.data_source import DataSource from flash.core.data.datasets import FlashDataset, FlashIterableDataset from flash.core.data.input_transform import InputTransform + from flash.core.data.io.output import Output from flash.core.data.process import Postprocess, Preprocess, Serializer from flash.core.model import Task # noqa: E402 from flash.core.trainer import Trainer # noqa: E402 @@ -44,9 +45,10 @@ "FlashCallback", "FlashDataset", "FlashIterableDataset", - "Preprocess", "InputTransform", + "Output", "Postprocess", + "Preprocess", "Serializer", "Task", "Trainer", diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index 18f215b395..9d895279d4 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -25,7 +25,7 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_AVAILABLE -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE if _AUDIO_AVAILABLE: from transformers import Wav2Vec2Processor @@ -41,7 +41,7 @@ class SpeechRecognition(Task): learning_rate: Learning rate to use for training, defaults to ``1e-5``. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. - serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. """ backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES @@ -54,7 +54,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-5, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings @@ -68,7 +68,7 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, learning_rate=learning_rate, - serializer=serializer, + output=output, ) self.save_hyperparameters() diff --git a/flash/core/classification.py b/flash/core/classification.py index 884dffbdfb..90e837b389 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -20,7 +20,7 @@ from flash.core.adapter import AdapterTask from flash.core.data.data_source import DefaultDataKeys, LabelsState -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output from flash.core.model import Task from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires @@ -68,7 +68,7 @@ def __init__( loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + output: Optional[Union[Output, Mapping[str, Output]]] = None, **kwargs, ) -> None: @@ -78,7 +78,7 @@ def __init__( *args, loss_fn=loss_fn, metrics=metrics, - serializer=serializer or Classes(multi_label=multi_label), + output=output or Classes(multi_label=multi_label), **kwargs, ) @@ -91,7 +91,7 @@ def __init__( loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, multi_label: bool = False, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + output: Optional[Union[Output, Mapping[str, Output]]] = None, **kwargs, ) -> None: @@ -101,13 +101,13 @@ def __init__( *args, loss_fn=loss_fn, metrics=metrics, - serializer=serializer or Classes(multi_label=multi_label), + output=output or Classes(multi_label=multi_label), **kwargs, ) -class ClassificationSerializer(Serializer): - """A base class for classification serializers. +class ClassificationOutput(Output): + """A base class for classification outputs. Args: multi_label: If true, treats outputs as multi label logits. @@ -123,12 +123,12 @@ def multi_label(self) -> bool: return self._mutli_label -class PredsClassificationSerializer(ClassificationSerializer): - """A :class:`~flash.core.classification.ClassificationSerializer` which gets the +class PredsClassificationOutput(ClassificationOutput): + """A :class:`~flash.core.classification.ClassificationOutput` which gets the :attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` from the sample. """ - def serialize(self, sample: Any) -> Any: + def transform(self, sample: Any) -> Any: if isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample: sample = sample[DefaultDataKeys.PREDS] if not isinstance(sample, torch.Tensor): @@ -136,26 +136,26 @@ def serialize(self, sample: Any) -> Any: return sample -class Logits(PredsClassificationSerializer): - """A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list.""" +class Logits(PredsClassificationOutput): + """A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list.""" - def serialize(self, sample: Any) -> Any: - return super().serialize(sample).tolist() + def transform(self, sample: Any) -> Any: + return super().transform(sample).tolist() -class Probabilities(PredsClassificationSerializer): - """A :class:`.Serializer` which applies a softmax to the model outputs (assumed to be logits) and converts to a +class Probabilities(PredsClassificationOutput): + """A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a list.""" - def serialize(self, sample: Any) -> Any: - sample = super().serialize(sample) + def transform(self, sample: Any) -> Any: + sample = super().transform(sample) if self.multi_label: return torch.sigmoid(sample).tolist() return torch.softmax(sample, -1).tolist() -class Classes(PredsClassificationSerializer): - """A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and +class Classes(PredsClassificationOutput): + """A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and converts to a list. Args: @@ -168,8 +168,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5): self.threshold = threshold - def serialize(self, sample: Any) -> Union[int, List[int]]: - sample = super().serialize(sample) + def transform(self, sample: Any) -> Union[int, List[int]]: + sample = super().transform(sample) if self.multi_label: one_hot = (sample.sigmoid() > self.threshold).int().tolist() result = [] @@ -181,7 +181,7 @@ def serialize(self, sample: Any) -> Union[int, List[int]]: class Labels(Classes): - """A :class:`.Serializer` which converts the model outputs (either logits or probabilities) to the label of the + """A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the argmax classification. Args: @@ -198,7 +198,7 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False if labels is not None: self.set_state(LabelsState(labels)) - def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: + def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]: labels = None if self._labels is not None: @@ -208,18 +208,18 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: if state is not None: labels = state.labels - classes = super().serialize(sample) + classes = super().transform(sample) if labels is not None: if self.multi_label: return [labels[cls] for cls in classes] return labels[classes] - rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning) + rank_zero_warn("No LabelsState was found, this output will act as a Classes output.", UserWarning) return classes -class FiftyOneLabels(ClassificationSerializer): - """A :class:`.Serializer` which converts the model outputs to FiftyOne classification format. +class FiftyOneLabels(ClassificationOutput): + """A :class:`.Output` which converts the model outputs to FiftyOne classification format. Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not @@ -254,7 +254,7 @@ def __init__( if labels is not None: self.set_state(LabelsState(labels)) - def serialize( + def transform( self, sample: Any, ) -> Union[Classification, Classifications, Dict[str, Any]]: diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 91b997dadb..4dfbf33dac 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -28,7 +28,7 @@ from flash.core.utilities.stages import RunningStage if TYPE_CHECKING: - from flash.core.data.process import Deserializer, Preprocess, Serializer + from flash.core.data.process import Deserializer, Preprocess class _Sequential(torch.nn.Module): @@ -135,18 +135,6 @@ def forward(self, sample: str): return sample -class _SerializeProcessor(torch.nn.Module): - def __init__( - self, - serializer: "Serializer", - ): - super().__init__() - self.serializer = convert_to_modules(serializer) - - def forward(self, sample): - return self.serializer(sample) - - class _Preprocessor(torch.nn.Module): """ This class is used to encapsultate the following functions of a Preprocess Object: @@ -274,7 +262,7 @@ def __init__( uncollate_fn: Callable, per_batch_transform: Callable, per_sample_transform: Callable, - serializer: Optional[Callable], + output: Optional[Callable], save_fn: Optional[Callable] = None, save_per_sample: bool = False, is_serving: bool = False, @@ -283,7 +271,7 @@ def __init__( self.uncollate_fn = convert_to_modules(uncollate_fn) self.per_batch_transform = convert_to_modules(per_batch_transform) self.per_sample_transform = convert_to_modules(per_sample_transform) - self.serializer = convert_to_modules(serializer) + self.output = convert_to_modules(output) self.save_fn = convert_to_modules(save_fn) self.save_per_sample = convert_to_modules(save_per_sample) self.is_serving = is_serving @@ -304,8 +292,8 @@ def forward(self, batch: Sequence[Any]): final_preds = [self.per_sample_transform(sample) for sample in uncollated] - if self.serializer is not None: - final_preds = [self.serializer(sample) for sample in final_preds] + if self.output is not None: + final_preds = [self.output(sample) for sample in final_preds] if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor): final_preds = torch.stack(final_preds) @@ -326,7 +314,7 @@ def __str__(self) -> str: f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" f"\t(uncollate_fn): {str(self.uncollate_fn)}\n" f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" - f"\t(serializer): {str(self.serializer)}" + f"\t(output): {str(self.output)}" ) diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 15f0afd035..2dcdcb7294 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -24,9 +24,10 @@ import flash from flash.core.data.auto_dataset import IterableAutoDataset -from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential, _SerializeProcessor +from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential from flash.core.data.data_source import DataSource -from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer +from flash.core.data.io.output import _OutputProcessor, Output +from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess from flash.core.data.properties import ProcessState from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0 @@ -103,26 +104,26 @@ def __init__( preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None, deserializer: Optional[Deserializer] = None, - serializer: Optional[Serializer] = None, + output: Optional[Output] = None, ) -> None: self.data_source = data_source self._preprocess_pipeline = preprocess or DefaultPreprocess() self._postprocess_pipeline = postprocess or Postprocess() - self._serializer = serializer or Serializer() + self._output = output or Output() self._deserializer = deserializer or Deserializer() self._running_stage = None def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> DataPipelineState: """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.Preprocess`, - :class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will + :class:`.Postprocess`, and :class:`.Output`. Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() if self.data_source is not None: self.data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) - self._serializer.attach_data_pipeline_state(data_pipeline_state) + self._output.attach_data_pipeline_state(data_pipeline_state) return data_pipeline_state @property @@ -181,8 +182,8 @@ def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor: def postprocessor(self, running_stage: RunningStage, is_serving=False) -> _Postprocessor: return self._create_uncollate_postprocessors(running_stage, is_serving=is_serving) - def serialize_processor(self) -> _SerializeProcessor: - return _SerializeProcessor(self._serializer) + def output_processor(self) -> _OutputProcessor: + return _OutputProcessor(self._output) @classmethod def _resolve_function_hierarchy( @@ -477,7 +478,7 @@ def _create_uncollate_postprocessors( getattr(postprocess, func_names["uncollate"]), getattr(postprocess, func_names["per_batch_transform"]), getattr(postprocess, func_names["per_sample_transform"]), - serializer=None if is_serving else self._serializer, + output=None if is_serving else self._output, save_fn=save_fn, save_per_sample=save_per_sample, is_serving=is_serving, @@ -588,7 +589,7 @@ def __str__(self) -> str: data_source: DataSource = self.data_source preprocess: Preprocess = self._preprocess_pipeline postprocess: Postprocess = self._postprocess_pipeline - serializer: Serializer = self._serializer + output: Output = self._output deserializer: Deserializer = self._deserializer return ( f"{self.__class__.__name__}(" @@ -596,7 +597,7 @@ def __str__(self) -> str: f"deserializer={deserializer}, " f"preprocess={preprocess}, " f"postprocess={postprocess}, " - f"serializer={serializer})" + f"output={output})" ) diff --git a/flash/core/data/io/__init__.py b/flash/core/data/io/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/core/data/io/output.py b/flash/core/data/io/output.py new file mode 100644 index 0000000000..18d50b73a1 --- /dev/null +++ b/flash/core/data/io/output.py @@ -0,0 +1,85 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Mapping + +import torch + +import flash +from flash.core.data.properties import Properties +from flash.core.data.utils import convert_to_modules + + +class Output(Properties): + """An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which + is used to convert the model output into the desired output format when predicting.""" + + def __init__(self): + super().__init__() + self._is_enabled = True + + def enable(self): + """Enable output transformation.""" + self._is_enabled = True + + def disable(self): + """Disable output transformation.""" + self._is_enabled = False + + @staticmethod + def transform(sample: Any) -> Any: + """Convert the given sample into the desired output format. + + Args: + sample: The output from the :class:`.Postprocess`. + + Returns: + The converted output. + """ + return sample + + def __call__(self, sample: Any) -> Any: + if self._is_enabled: + return self.transform(sample) + return sample + + +class OutputMapping(Output): + """If the model output is a dictionary, then the :class:`.OutputMapping` enables each entry in the dictionary + to be passed to it's own :class:`.Output`.""" + + def __init__(self, outputs: Mapping[str, Output]): + super().__init__() + + self._outputs = outputs + + def transform(self, sample: Any) -> Any: + if isinstance(sample, Mapping): + return {key: output.transform(sample[key]) for key, output in self._outputs.items()} + raise ValueError("The model output must be a mapping when using an OutputMapping.") + + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): + for output in self._outputs.values(): + output.attach_data_pipeline_state(data_pipeline_state) + + +class _OutputProcessor(torch.nn.Module): + def __init__( + self, + output: "Output", + ): + super().__init__() + self.output = convert_to_modules(output) + + def forward(self, sample): + return self.output(sample) diff --git a/flash/core/data/serialization.py b/flash/core/data/output.py similarity index 76% rename from flash/core/data/serialization.py rename to flash/core/data/output.py index 190bbffe5b..ad91f4494e 100644 --- a/flash/core/data/serialization.py +++ b/flash/core/data/output.py @@ -14,11 +14,11 @@ from typing import Any, List, Union from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output -class Preds(Serializer): - """A :class:`~flash.core.data.process.Serializer` which returns the "preds" from the model outputs.""" +class Preds(Output): + """A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs.""" - def serialize(self, sample: Any) -> Union[int, List[int]]: + def transform(self, sample: Any) -> Union[int, List[int]]: return sample.get(DefaultDataKeys.PREDS, sample) if isinstance(sample, dict) else sample diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 641a29a501..a9b56d312b 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import inspect import os from abc import ABC, abstractclassmethod, abstractmethod from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union import torch +from _warnings import warn +from deprecate import deprecated from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from torch.utils.data._utils.collate import default_collate @@ -25,6 +28,7 @@ from flash.core.data.batch import default_uncollate from flash.core.data.callback import FlashCallback from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.output import Output from flash.core.data.properties import ProcessState, Properties from flash.core.data.states import ( CollateFn, @@ -582,59 +586,6 @@ def _save_sample(self, sample: Any) -> None: self.save_sample(sample, self.format_sample_save_path(self._save_path)) -class Serializer(Properties): - """A :class:`.Serializer` encapsulates a single ``serialize`` method which is used to convert the model output - into the desired output format when predicting.""" - - def __init__(self): - super().__init__() - self._is_enabled = True - - def enable(self): - """Enable serialization.""" - self._is_enabled = True - - def disable(self): - """Disable serialization.""" - self._is_enabled = False - - @staticmethod - def serialize(sample: Any) -> Any: - """Serialize the given sample into the desired output format. - - Args: - sample: The output from the :class:`.Postprocess`. - - Returns: - The serialized output. - """ - return sample - - def __call__(self, sample: Any) -> Any: - if self._is_enabled: - return self.serialize(sample) - return sample - - -class SerializerMapping(Serializer): - """If the model output is a dictionary, then the :class:`.SerializerMapping` enables each entry in the - dictionary to be passed to it's own :class:`.Serializer`.""" - - def __init__(self, serializers: Mapping[str, Serializer]): - super().__init__() - - self._serializers = serializers - - def serialize(self, sample: Any) -> Any: - if isinstance(sample, Mapping): - return {key: serializer.serialize(sample[key]) for key, serializer in self._serializers.items()} - raise ValueError("The model output must be a mapping when using a SerializerMapping.") - - def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): - for serializer in self._serializers.values(): - serializer.attach_data_pipeline_state(data_pipeline_state) - - class Deserializer(Properties): """Deserializer.""" @@ -651,7 +602,7 @@ def __call__(self, sample: Any) -> Any: class DeserializerMapping(Deserializer): - # TODO: This is essentially a duplicate of SerializerMapping, should be abstracted away somewhere + # TODO: This is essentially a duplicate of OutputMapping, should be abstracted away somewhere """Deserializer Mapping.""" def __init__(self, deserializers: Mapping[str, Deserializer]): @@ -667,3 +618,41 @@ def deserialize(self, sample: Any) -> Any: def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): for deserializer in self._deserializers.values(): deserializer.attach_data_pipeline_state(data_pipeline_state) + + +class Serializer(Output): + """Deprecated. + + Use ``Output`` instead. + """ + + @deprecated( + None, + "0.6.0", + "0.7.0", + template_mgs="`Serializer` was deprecated in v%(deprecated_in)s in favor of `Output`. " + "It will be removed in v%(remove_in)s.", + stream=functools.partial(warn, category=FutureWarning), + ) + def __init__(self): + super().__init__() + self._is_enabled = True + + @staticmethod + @deprecated( + None, + "0.6.0", + "0.7.0", + template_mgs="`Serializer` was deprecated in v%(deprecated_in)s in favor of `Output`. " + "It will be removed in v%(remove_in)s.", + stream=functools.partial(warn, category=FutureWarning), + ) + def serialize(sample: Any) -> Any: + """Deprecated. + + Use ``Output.transform`` instead. + """ + return sample + + def transform(self, sample: Any) -> Any: + return self.serialize(sample) diff --git a/flash/core/integrations/fiftyone/utils.py b/flash/core/integrations/fiftyone/utils.py index d5c8ae3fb3..021a46ddc9 100644 --- a/flash/core/integrations/fiftyone/utils.py +++ b/flash/core/integrations/fiftyone/utils.py @@ -21,7 +21,7 @@ def visualize( wait: Optional[bool] = False, **kwargs ) -> Optional[Session]: - """Visualizes predictions from a model with a FiftyOne Serializer in the + """Visualizes predictions from a model with a FiftyOne Output in the :ref:`FiftyOne App `. This method can be used in all of the following environments: diff --git a/flash/core/model.py b/flash/core/model.py index 85e13aa8de..b4c5aca3c7 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -18,10 +18,12 @@ from copy import deepcopy from importlib import import_module from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union +from warnings import warn import pytorch_lightning as pl import torch import torchmetrics +from deprecate import deprecated from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config @@ -37,14 +39,8 @@ from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource -from flash.core.data.process import ( - Deserializer, - DeserializerMapping, - Postprocess, - Preprocess, - Serializer, - SerializerMapping, -) +from flash.core.data.io.output import Output, OutputMapping +from flash.core.data.process import Deserializer, DeserializerMapping, Postprocess, Preprocess from flash.core.data.properties import ProcessState from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY @@ -61,9 +57,9 @@ METRICS_TYPE, MODEL_TYPE, OPTIMIZER_TYPE, + OUTPUT_TYPE, POSTPROCESS_TYPE, PREPROCESS_TYPE, - SERIALIZER_TYPE, ) @@ -323,7 +319,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check deserialize the input preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task. postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task. - serializer: Either a single :class:`~flash.core.data.process.Serializer` or a mapping of these to + output: Either a single :class:`~flash.core.data.io.output.Output` or a mapping of these to serialize the output e.g. convert the model output into the desired output format when predicting. """ @@ -343,7 +339,7 @@ def __init__( deserializer: DESERIALIZER_TYPE = None, preprocess: PREPROCESS_TYPE = None, postprocess: POSTPROCESS_TYPE = None, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, ): super().__init__() if model is not None: @@ -362,11 +358,11 @@ def __init__( self._deserializer: Optional[Deserializer] = None self._preprocess: Optional[Preprocess] = preprocess self._postprocess: Optional[Postprocess] = postprocess - self._serializer: Optional[Serializer] = None + self._output: Optional[Output] = None - # Explicitly set the serializer to call the setter + # Explicitly set the output to call the setter self.deserializer = deserializer - self.serializer = serializer + self.output = output def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: """Implement the core logic for the training/validation/test step. By default this includes: @@ -574,28 +570,28 @@ def _resolve( old_deserializer: Optional[Deserializer], old_preprocess: Optional[Preprocess], old_postprocess: Optional[Postprocess], - old_serializer: Optional[Serializer], + old_output: Optional[Output], new_deserializer: Optional[Deserializer], new_preprocess: Optional[Preprocess], new_postprocess: Optional[Postprocess], - new_serializer: Optional[Serializer], - ) -> Tuple[Optional[Deserializer], Optional[Preprocess], Optional[Postprocess], Optional[Serializer]]: + new_output: Optional[Output], + ) -> Tuple[Optional[Deserializer], Optional[Preprocess], Optional[Postprocess], Optional[Output]]: """Resolves the correct :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and - :class:`~flash.core.data.process.Serializer` to use, choosing ``new_*`` if it is not None or a base class + :class:`~flash.core.data.io.output.Output` to use, choosing ``new_*`` if it is not None or a base class (:class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, or - :class:`~flash.core.data.process.Serializer`) and ``old_*`` otherwise. + :class:`~flash.core.data.io.output.Output`) and ``old_*`` otherwise. Args: old_preprocess: :class:`~flash.core.data.process.Preprocess` to be overridden. old_postprocess: :class:`~flash.core.data.process.Postprocess` to be overridden. - old_serializer: :class:`~flash.core.data.process.Serializer` to be overridden. + old_output: :class:`~flash.core.data.io.output.Output` to be overridden. new_preprocess: :class:`~flash.core.data.process.Preprocess` to override with. new_postprocess: :class:`~flash.core.data.process.Postprocess` to override with. - new_serializer: :class:`~flash.core.data.process.Serializer` to override with. + new_output: :class:`~flash.core.data.io.output.Output` to override with. Returns: The resolved :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, - and :class:`~flash.core.data.process.Serializer`. + and :class:`~flash.core.data.io.output.Output`. """ deserializer = old_deserializer if new_deserializer is not None and type(new_deserializer) != Deserializer: @@ -609,11 +605,11 @@ def _resolve( if new_postprocess is not None and type(new_postprocess) != Postprocess: postprocess = new_postprocess - serializer = old_serializer - if new_serializer is not None and type(new_serializer) != Serializer: - serializer = new_serializer + output = old_output + if new_output is not None and type(new_output) != Output: + output = new_output - return deserializer, preprocess, postprocess, serializer + return deserializer, preprocess, postprocess, output @torch.jit.unused @property @@ -628,20 +624,46 @@ def deserializer(self, deserializer: Union[Deserializer, Mapping[str, Deserializ @torch.jit.unused @property - def serializer(self) -> Optional[Serializer]: - """The current :class:`.Serializer` associated with this model. + def output(self) -> Optional[Output]: + """The current :class:`.Output` associated with this model.""" + return self._output + + @torch.jit.unused + @output.setter + def output(self, output: Union[Output, Mapping[str, Output]]): + if isinstance(output, Mapping): + output = OutputMapping(output) + self._output = output - If this property was set to a mapping - (e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.MappingSerializer`. + @torch.jit.unused + @property + @deprecated( + None, + "0.6.0", + "0.7.0", + template_mgs="`Task.serializer` was deprecated in v%(deprecated_in)s in favor of `Task.output`. " + "It will be removed in v%(remove_in)s.", + stream=functools.partial(warn, category=FutureWarning), + ) + def serializer(self) -> Optional[Output]: + """Deprecated. + + Use ``Task.output`` instead. """ - return self._serializer + return self.output @torch.jit.unused @serializer.setter - def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]): - if isinstance(serializer, Mapping): - serializer = SerializerMapping(serializer) - self._serializer = serializer + @deprecated( + None, + "0.6.0", + "0.7.0", + template_mgs="`Task.serializer` was deprecated in v%(deprecated_in)s in favor of `Task.output`. " + "It will be removed in v%(remove_in)s.", + stream=functools.partial(warn, category=FutureWarning), + ) + def serializer(self, serializer: Union[Output, Mapping[str, Output]]): + self.output = serializer def build_data_pipeline( self, @@ -668,7 +690,7 @@ def build_data_pipeline( Returns: The fully resolved :class:`.DataPipeline`. """ - deserializer, old_data_source, preprocess, postprocess, serializer = None, None, None, None, None + deserializer, old_data_source, preprocess, postprocess, output = None, None, None, None, None # Datamodule datamodule = None @@ -681,32 +703,32 @@ def build_data_pipeline( old_data_source = getattr(datamodule.data_pipeline, "data_source", None) preprocess = getattr(datamodule.data_pipeline, "_preprocess_pipeline", None) postprocess = getattr(datamodule.data_pipeline, "_postprocess_pipeline", None) - serializer = getattr(datamodule.data_pipeline, "_serializer", None) + output = getattr(datamodule.data_pipeline, "_output", None) deserializer = getattr(datamodule.data_pipeline, "_deserializer", None) # Defaults / task attributes - deserializer, preprocess, postprocess, serializer = Task._resolve( + deserializer, preprocess, postprocess, output = Task._resolve( deserializer, preprocess, postprocess, - serializer, + output, self._deserializer, self._preprocess, self._postprocess, - self._serializer, + self._output, ) # Datapipeline if data_pipeline is not None: - deserializer, preprocess, postprocess, serializer = Task._resolve( + deserializer, preprocess, postprocess, output = Task._resolve( deserializer, preprocess, postprocess, - serializer, + output, getattr(data_pipeline, "_deserializer", None), getattr(data_pipeline, "_preprocess_pipeline", None), getattr(data_pipeline, "_postprocess_pipeline", None), - getattr(data_pipeline, "_serializer", None), + getattr(data_pipeline, "_output", None), ) data_source = data_source or old_data_source @@ -720,7 +742,7 @@ def build_data_pipeline( if deserializer is None or type(deserializer) is Deserializer: deserializer = getattr(preprocess, "deserializer", deserializer) - data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer) + data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, output) self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() self.attach_data_pipeline_state(self._data_pipeline_state) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) @@ -744,15 +766,15 @@ def data_pipeline(self) -> DataPipeline: @torch.jit.unused @data_pipeline.setter def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: - self._deserializer, self._preprocess, self._postprocess, self.serializer = Task._resolve( + self._deserializer, self._preprocess, self._postprocess, self.output = Task._resolve( self._deserializer, self._preprocess, self._postprocess, - self._serializer, + self._output, getattr(data_pipeline, "_deserializer", None), getattr(data_pipeline, "_preprocess_pipeline", None), getattr(data_pipeline, "_postprocess_pipeline", None), - getattr(data_pipeline, "_serializer", None), + getattr(data_pipeline, "_output", None), ) # self._preprocess.state_dict() diff --git a/flash/core/regression.py b/flash/core/regression.py index edd351d02f..8823afe7d2 100644 --- a/flash/core/regression.py +++ b/flash/core/regression.py @@ -17,8 +17,8 @@ import torch.nn.functional as F import torchmetrics -from flash.core.data.process import Serializer from flash.core.model import Task +from flash.core.utilities.types import OUTPUT_TYPE class RegressionMixin: @@ -42,7 +42,7 @@ def __init__( *args, loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + output: OUTPUT_TYPE = None, **kwargs, ) -> None: @@ -52,6 +52,6 @@ def __init__( *args, loss_fn=loss_fn, metrics=metrics, - serializer=serializer, + output=output, **kwargs, ) diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 72f9a57574..6a70edd41d 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -26,15 +26,15 @@ def deserialize(self, data: str) -> Any: # pragma: no cover class FlashOutputs(BaseType): def __init__( self, - serializer: Callable, + output: Callable, ): - self._serializer = serializer + self._output = output def serialize(self, outputs) -> Any: # pragma: no cover results = [] if isinstance(outputs, (list, torch.Tensor)): for output in outputs: - result = self._serializer(output) + result = self._output(output) if isinstance(result, Mapping): result = result[DefaultDataKeys.PREDS] results.append(result) @@ -64,7 +64,7 @@ def __init__(self, model): @expose( inputs={"inputs": FlashInputs(data_pipeline.deserialize_processor())}, - outputs={"outputs": FlashOutputs(data_pipeline.serialize_processor())}, + outputs={"outputs": FlashOutputs(data_pipeline.output_processor())}, ) def predict(self, inputs): with torch.no_grad(): diff --git a/flash/core/serve/types/image.py b/flash/core/serve/types/image.py index 82a82219ea..a20e5fb188 100644 --- a/flash/core/serve/types/image.py +++ b/flash/core/serve/types/image.py @@ -16,7 +16,7 @@ @dataclass(unsafe_hash=True) class Image(BaseType): - """Image serializer. + """Image output. Notes ----- diff --git a/flash/core/utilities/types.py b/flash/core/utilities/types.py index 8138db88b5..e7597d963a 100644 --- a/flash/core/utilities/types.py +++ b/flash/core/utilities/types.py @@ -3,7 +3,8 @@ from torch import nn from torchmetrics import Metric -from flash.core.data.process import Deserializer, Postprocess, Preprocess, Serializer +from flash.core.data.io.output import Output +from flash.core.data.process import Deserializer, Postprocess, Preprocess MODEL_TYPE = Optional[nn.Module] LOSS_FN_TYPE = Optional[Union[Callable, Mapping, Sequence]] @@ -15,4 +16,4 @@ DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]] PREPROCESS_TYPE = Optional[Preprocess] POSTPROCESS_TYPE = Optional[Postprocess] -SERIALIZER_TYPE = Optional[Union[Serializer, Mapping[str, Serializer]]] +OUTPUT_TYPE = Optional[Union[Output, Mapping[str, Output]]] diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 4d23f9aa3a..ff2599df95 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -19,7 +19,7 @@ from flash.core.classification import ClassificationAdapterTask, Labels from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.classification.adapters import TRAINING_STRATEGIES from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES @@ -60,7 +60,7 @@ def fn_resnet(pretrained: bool = True): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. multi_label: Whether the targets are multi-label or not. - serializer: A instance of :class:`~flash.core.data.process.Serializer` or a mapping consisting of such + output: A instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such to use when serializing prediction outputs. training_strategy: string indicating the training strategy. Adjust if you want to use `learn2learn` for doing meta-learning research @@ -85,7 +85,7 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-3, multi_label: bool = False, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, training_strategy: Optional[str] = "default", training_strategy_kwargs: Optional[Dict[str, Any]] = None, ): @@ -137,7 +137,7 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, multi_label=multi_label, - serializer=serializer or Labels(multi_label=multi_label), + output=output or Labels(multi_label=multi_label), ) @classmethod diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 6a7cd8aff3..b0075a6956 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -14,9 +14,9 @@ from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.serialization import Preds +from flash.core.data.output import Preds from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.detection.backbones import OBJECT_DETECTION_HEADS @@ -42,7 +42,7 @@ class ObjectDetector(AdapterTask): pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. Has no effect for custom models. learning_rate: The learning rate to use for training - serializer: A instance of :class:`~flash.core.data.process.Serializer` or a mapping consisting of such + output: A instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such to use when serializing prediction outputs. kwargs: additional kwargs nessesary for initializing the backbone task """ @@ -60,7 +60,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-3, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() @@ -80,7 +80,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - serializer=serializer or Preds(), + output=output or Preds(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/detection/serialization.py b/flash/image/detection/output.py similarity index 92% rename from flash/image/detection/serialization.py rename to flash/image/detection/output.py index 115f7d3118..1b56c734d2 100644 --- a/flash/image/detection/serialization.py +++ b/flash/image/detection/output.py @@ -16,7 +16,7 @@ from pytorch_lightning.utilities import rank_zero_warn from flash.core.data.data_source import DefaultDataKeys, LabelsState -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires Detections = None @@ -28,8 +28,8 @@ fo = None -class FiftyOneDetectionLabels(Serializer): - """A :class:`.Serializer` which converts model outputs to FiftyOne detection format. +class FiftyOneDetectionLabels(Output): + """A :class:`.Output` which converts model outputs to FiftyOne detection format. Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not @@ -55,9 +55,9 @@ def __init__( if labels is not None: self.set_state(LabelsState(labels)) - def serialize(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]: + def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]: if DefaultDataKeys.METADATA not in sample: - raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels serializer.") + raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels output.") labels = None if self._labels is not None: diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index 042c417848..bfb2e52c18 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -17,7 +17,7 @@ import torch from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.utilities.imports import _FASTFACE_AVAILABLE @@ -26,8 +26,8 @@ LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, + OUTPUT_TYPE, PREPROCESS_TYPE, - SERIALIZER_TYPE, ) from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES from flash.image.face_detection.data import FaceDetectionPreprocess @@ -44,10 +44,10 @@ def freeze_before_training(self, pl_module: pl.LightningModule) -> None: self.freeze(modules=pl_module.model.backbone, train_bn=self.train_bn) -class DetectionLabels(Serializer): - """A :class:`.Serializer` which extracts predictions from sample dict.""" +class DetectionLabels(Output): + """A :class:`.Output` which extracts predictions from sample dict.""" - def serialize(self, sample: Any) -> Dict[str, Any]: + def transform(self, sample: Any) -> Dict[str, Any]: return sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample @@ -78,7 +78,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-4, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, preprocess: PREPROCESS_TYPE = None, **kwargs: Any, ): @@ -96,7 +96,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - serializer=serializer or DetectionLabels(), + output=output or DetectionLabels(), preprocess=preprocess or FaceDetectionPreprocess(), ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index ae68668768..c947a8c4f6 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -17,9 +17,9 @@ from flash.core.adapter import AdapterTask from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.serialization import Preds +from flash.core.data.output import Preds from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS from flash.image.instance_segmentation.data import InstanceSegmentationPostProcess, InstanceSegmentationPreprocess @@ -62,7 +62,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() @@ -82,7 +82,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - serializer=serializer or Preds(), + output=output or Preds(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 306b334d12..d9a844dc0f 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -14,9 +14,9 @@ from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.serialization import Preds +from flash.core.data.output import Preds from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS @@ -59,7 +59,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() @@ -80,7 +80,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - serializer=serializer or Preds(), + output=output or Preds(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index d859964239..7819dd4589 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -44,7 +44,7 @@ ) from flash.core.utilities.stages import RunningStage from flash.image.data import ImageDeserializer, IMG_EXTENSIONS -from flash.image.segmentation.serialization import SegmentationLabels +from flash.image.segmentation.output import SegmentationLabels from flash.image.segmentation.transforms import default_transforms, predict_default_transforms, train_default_transforms SampleCollection = None diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index b0b293ad6b..a9589b20ad 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -29,12 +29,12 @@ LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, + OUTPUT_TYPE, POSTPROCESS_TYPE, - SERIALIZER_TYPE, ) from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS -from flash.image.segmentation.serialization import SegmentationLabels +from flash.image.segmentation.output import SegmentationLabels if _KORNIA_AVAILABLE: import kornia as K @@ -68,7 +68,7 @@ class SemanticSegmentation(ClassificationTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.IOU`. learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. - serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. postprocess: :class:`~flash.core.data.process.Postprocess` use for post processing samples. """ @@ -94,7 +94,7 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-3, multi_label: bool = False, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, postprocess: POSTPROCESS_TYPE = None, ) -> None: if metrics is None: @@ -114,7 +114,7 @@ def __init__( lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - serializer=serializer or SegmentationLabels(), + output=output or SegmentationLabels(), postprocess=postprocess or self.postprocess_cls(), ) diff --git a/flash/image/segmentation/serialization.py b/flash/image/segmentation/output.py similarity index 89% rename from flash/image/segmentation/serialization.py rename to flash/image/segmentation/output.py index bad3655894..e8873a421d 100644 --- a/flash/image/segmentation/serialization.py +++ b/flash/image/segmentation/output.py @@ -18,7 +18,7 @@ import flash from flash.core.data.data_source import DefaultDataKeys, ImageLabelsMap -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, @@ -46,9 +46,9 @@ K = None -class SegmentationLabels(Serializer): - """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification per pixel - in the image for semantic segmentation tasks. +class SegmentationLabels(Output): + """A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in + the image for semantic segmentation tasks. Args: labels_map: A dictionary that map the labels ids to pixel intensities. @@ -90,7 +90,7 @@ def _visualize(self, labels): plt.imshow(labels_vis) plt.show() - def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: + def transform(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: preds = sample[DefaultDataKeys.PREDS] assert len(preds.shape) == 3, preds.shape labels = torch.argmax(preds, dim=-3) # HxW @@ -101,7 +101,7 @@ def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: class FiftyOneSegmentationLabels(SegmentationLabels): - """A :class:`.Serializer` which converts the model outputs to FiftyOne segmentation format. + """A :class:`.Output` which converts the model outputs to FiftyOne segmentation format. Args: labels_map: A dictionary that map the labels ids to pixel intensities. @@ -122,8 +122,8 @@ def __init__( self.return_filepath = return_filepath - def serialize(self, sample: Dict[str, torch.Tensor]) -> Union[Segmentation, Dict[str, Any]]: - labels = super().serialize(sample) + def transform(self, sample: Dict[str, torch.Tensor]) -> Union[Segmentation, Dict[str, Any]]: + labels = super().transform(sample) fo_predictions = fol.Segmentation(mask=np.array(labels)) if self.return_filepath: filepath = sample[DefaultDataKeys.METADATA]["filepath"] diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index a03575a64a..1ac19d005d 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -20,7 +20,7 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _IMAGE_AVAILABLE -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES if _IMAGE_AVAILABLE: @@ -61,7 +61,7 @@ class StyleTransfer(Task): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. learning_rate: Learning rate to use for training, defaults to ``1e-3``. - serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. """ backbones: FlashRegistry = STYLE_TRANSFER_BACKBONES @@ -80,7 +80,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-3, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, ): self.save_hyperparameters(ignore="style_image") @@ -110,7 +110,7 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, learning_rate=learning_rate, - serializer=serializer, + output=output, ) self.perceptual_loss = perceptual_loss diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index b35604cae3..efe402909e 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -20,19 +20,19 @@ from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output from flash.core.data.states import CollateFn from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES __FILE_EXAMPLE__ = "pointcloud_detection" -class PointCloudObjectDetectorSerializer(Serializer): +class PointCloudObjectDetectorOutput(Output): pass @@ -53,7 +53,7 @@ class PointCloudObjectDetector(Task): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs. lambda_loss_cls: The value to scale the loss classification. lambda_loss_bbox: The value to scale the bounding boxes loss. lambda_loss_dir: The value to scale the bounding boxes direction loss. @@ -73,7 +73,7 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, - serializer: SERIALIZER_TYPE = PointCloudObjectDetectorSerializer(), + output: OUTPUT_TYPE = PointCloudObjectDetectorOutput(), lambda_loss_cls: float = 1.0, lambda_loss_bbox: float = 1.0, lambda_loss_dir: float = 1.0, @@ -86,7 +86,7 @@ def __init__( lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - serializer=serializer, + output=output, ) self.save_hyperparameters() diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index e8578b586c..227bf63e59 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -24,12 +24,12 @@ from flash.core.classification import ClassificationTask from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Serializer +from flash.core.data.io.output import Output from flash.core.data.states import CollateFn from flash.core.finetuning import BaseFinetuning from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES if _POINTCLOUD_AVAILABLE: @@ -63,7 +63,7 @@ def finetune_function( ) -class PointCloudSegmentationSerializer(Serializer): +class PointCloudSegmentationOutput(Output): pass @@ -84,7 +84,7 @@ class PointCloudSegmentation(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs. """ backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES @@ -103,7 +103,7 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: SERIALIZER_TYPE = PointCloudSegmentationSerializer(), + output: OUTPUT_TYPE = PointCloudSegmentationOutput(), ): import flash @@ -118,7 +118,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - serializer=serializer, + output=output, ) self.save_hyperparameters() diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index a0c0cc7114..cec72473bd 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -19,7 +19,7 @@ from flash.core.classification import ClassificationTask, Probabilities from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE if _TABULAR_AVAILABLE: from pytorch_tabnet.tab_network import TabNet @@ -42,7 +42,7 @@ class TabularClassifier(ClassificationTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. - serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. **tabnet_kwargs: Optional additional arguments for the TabNet model, see `pytorch_tabnet `_. """ @@ -60,7 +60,7 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, **tabnet_kwargs, ): self.save_hyperparameters() @@ -83,7 +83,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - serializer=serializer or Probabilities(), + output=output or Probabilities(), ) self.save_hyperparameters() diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index 7710332670..d5090a1298 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -19,7 +19,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.regression import RegressionTask from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE if _TABULAR_AVAILABLE: from pytorch_tabnet.tab_network import TabNet @@ -56,7 +56,7 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, **tabnet_kwargs, ): self.save_hyperparameters() @@ -78,7 +78,7 @@ def __init__( lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - serializer=serializer, + output=output, ) self.save_hyperparameters() diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index 3350972567..804549d83e 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -19,7 +19,7 @@ from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.template.classification.backbones import TEMPLATE_BACKBONES @@ -40,7 +40,7 @@ class TemplateSKLearnClassifier(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs. """ backbones: FlashRegistry = TEMPLATE_BACKBONES @@ -57,7 +57,7 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, ): super().__init__( model=None, @@ -67,7 +67,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - serializer=serializer or Labels(), + output=output or Labels(), ) self.save_hyperparameters() diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index dcf0e13bf3..da491f2026 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -22,7 +22,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.text.classification.backbones import TEXT_CLASSIFIER_BACKBONES from flash.text.ort_callback import ORTCallback @@ -46,7 +46,7 @@ class TextClassifier(ClassificationTask): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to `1e-3` multi_label: Whether the targets are multi-label or not. - serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ @@ -64,7 +64,7 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, enable_ort: bool = False, ): self.save_hyperparameters() @@ -84,7 +84,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - serializer=serializer or Labels(multi_label=multi_label), + output=output or Labels(multi_label=multi_label), ) self.enable_ort = enable_ort self.model = self.backbones.get(backbone)(num_labels=num_classes) diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index f70c913f54..6ae0340355 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -32,7 +32,7 @@ from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE from flash.core.utilities.providers import _PYTORCHVIDEO -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE _VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones") @@ -93,7 +93,7 @@ class VideoClassifier(ClassificationTask): head: either a `nn.Module` or a callable function that converts the features extrated from the backbone into class log probabilities (assuming default loss function). If `None`, will default to using a single linear layer. - serializer: A instance of :class:`~flash.core.data.process.Serializer` that determines how the output + output: A instance of :class:`~flash.core.data.io.output.Output` that determines how the output should be serialized e.g. convert the model output into the desired output format when predicting. """ @@ -113,7 +113,7 @@ def __init__( metrics: METRICS_TYPE = Accuracy(), learning_rate: float = 1e-3, head: Optional[Union[FunctionType, nn.Module]] = None, - serializer: SERIALIZER_TYPE = None, + output: OUTPUT_TYPE = None, ): super().__init__( model=None, @@ -122,7 +122,7 @@ def __init__( lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - serializer=serializer or Labels(), + output=output or Labels(), ) self.save_hyperparameters() diff --git a/flash_examples/integrations/baal/image_classification_active_learning.py b/flash_examples/integrations/baal/image_classification_active_learning.py index add36785b0..d006ce9312 100644 --- a/flash_examples/integrations/baal/image_classification_active_learning.py +++ b/flash_examples/integrations/baal/image_classification_active_learning.py @@ -34,7 +34,7 @@ torch.nn.Dropout(p=0.1), torch.nn.Linear(512, datamodule.num_classes), ) -model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities()) +model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities()) # 3.1 Create the trainer diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py index 1d539b4eaf..f7fc0db39a 100644 --- a/flash_examples/integrations/fiftyone/image_classification.py +++ b/flash_examples/integrations/fiftyone/image_classification.py @@ -37,7 +37,7 @@ model = ImageClassifier( backbone="resnet18", num_classes=datamodule.num_classes, - serializer=Labels(), + output=Labels(), ) trainer = flash.Trainer( max_epochs=1, @@ -56,7 +56,7 @@ model = ImageClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" ) -model.serializer = FiftyOneLabels(return_filepath=True) # output FiftyOne format +model.output = FiftyOneLabels(return_filepath=True) # output FiftyOne format predictions = trainer.predict(model, datamodule=datamodule) predictions = list(chain.from_iterable(predictions)) # flatten batches diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py index b7c12f79ca..96ea5ffc51 100644 --- a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py +++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py @@ -50,7 +50,7 @@ model = ImageClassifier( backbone="resnet18", num_classes=datamodule.num_classes, - serializer=Labels(), + output=Labels(), ) trainer = flash.Trainer( max_epochs=1, @@ -69,7 +69,7 @@ model = ImageClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" ) -model.serializer = FiftyOneLabels(return_filepath=False) # output FiftyOne format +model.output = FiftyOneLabels(return_filepath=False) # output FiftyOne format datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset) predictions = trainer.predict(model, datamodule=datamodule) predictions = list(chain.from_iterable(predictions)) # flatten batches diff --git a/flash_examples/integrations/fiftyone/object_detection.py b/flash_examples/integrations/fiftyone/object_detection.py index efec712477..8a0450c51e 100644 --- a/flash_examples/integrations/fiftyone/object_detection.py +++ b/flash_examples/integrations/fiftyone/object_detection.py @@ -17,7 +17,7 @@ from flash.core.integrations.fiftyone import visualize from flash.core.utilities.imports import example_requires from flash.image import ObjectDetectionData, ObjectDetector -from flash.image.detection.serialization import FiftyOneDetectionLabels +from flash.image.detection.output import FiftyOneDetectionLabels example_requires("image") @@ -41,8 +41,8 @@ trainer = flash.Trainer(max_epochs=1) trainer.finetune(model, datamodule=datamodule, strategy="freeze") -# 4. Set the serializer and get some predictions -model.serializer = FiftyOneDetectionLabels(return_filepath=True) # output FiftyOne format +# 4. Set the output and get some predictions +model.output = FiftyOneDetectionLabels(return_filepath=True) # output FiftyOne format predictions = trainer.predict(model, datamodule=datamodule) predictions = list(chain.from_iterable(predictions)) # flatten batches diff --git a/flash_examples/integrations/labelstudio/image_classification.py b/flash_examples/integrations/labelstudio/image_classification.py index 41e3aa7332..637af63c08 100644 --- a/flash_examples/integrations/labelstudio/image_classification.py +++ b/flash_examples/integrations/labelstudio/image_classification.py @@ -31,7 +31,7 @@ # 4. Predict from checkpoint model = ImageClassifier.load_from_checkpoint("image_classification_model.pt") -model.serializer = Labels() +model.output = Labels() predictions = model.predict( [ diff --git a/flash_examples/serve/semantic_segmentation/inference_server.py b/flash_examples/serve/semantic_segmentation/inference_server.py index 140b3c6c34..ea106da239 100644 --- a/flash_examples/serve/semantic_segmentation/inference_server.py +++ b/flash_examples/serve/semantic_segmentation/inference_server.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from flash.image import SemanticSegmentation -from flash.image.segmentation.serialization import SegmentationLabels +from flash.image.segmentation.output import SegmentationLabels model = SemanticSegmentation.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/0.5.2/semantic_segmentation_model.pt" ) -model.serializer = SegmentationLabels(visualize=False) +model.output = SegmentationLabels(visualize=False) model.serve() diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py index 08975b6cfb..c9365d773b 100644 --- a/flash_examples/serve/tabular_classification/inference_server.py +++ b/flash_examples/serve/tabular_classification/inference_server.py @@ -17,5 +17,5 @@ model = TabularClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt" ) -model.serializer = Labels(["Did not survive", "Survived"]) +model.output = Labels(["Did not survive", "Survived"]) model.serve() diff --git a/tests/core/data/io/__init__.py b/tests/core/data/io/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py new file mode 100644 index 0000000000..1d18ad1315 --- /dev/null +++ b/tests/core/data/io/test_output.py @@ -0,0 +1,109 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest.mock import Mock + +import pytest +import torch +from torch.utils.data import DataLoader + +from flash.core.classification import Labels +from flash.core.data.data_pipeline import DataPipeline, DataPipelineState +from flash.core.data.data_source import LabelsState +from flash.core.data.io.output import Output, OutputMapping +from flash.core.data.process import DefaultPreprocess +from flash.core.data.properties import ProcessState +from flash.core.model import Task +from flash.core.trainer import Trainer + + +def test_output_enable_disable(): + """Tests that ``Output`` can be enabled and disabled correctly.""" + + my_output = Output() + + assert my_output.transform("test") == "test" + my_output.transform = Mock() + + my_output.disable() + assert my_output("test") == "test" + my_output.transform.assert_not_called() + + my_output.enable() + my_output("test") + my_output.transform.assert_called_once() + + +def test_saving_with_output(tmpdir): + checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") + + class CustomModel(Task): + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + + output = Labels(["a", "b"]) + model = CustomModel() + trainer = Trainer(fast_dev_run=True) + data_pipeline = DataPipeline(preprocess=DefaultPreprocess(), output=output) + data_pipeline.initialize() + model.data_pipeline = data_pipeline + assert isinstance(model.preprocess, DefaultPreprocess) + dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) + trainer.fit(model, train_dataloader=dummy_data) + trainer.save_checkpoint(checkpoint_file) + model = CustomModel.load_from_checkpoint(checkpoint_file) + assert isinstance(model._data_pipeline_state, DataPipelineState) + assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) + + +def test_output_mapping(): + """Tests that ``OutputMapping`` correctly passes its inputs to the underlying outputs. + + Also checks that state is retrieved / loaded correctly. + """ + + output1 = Output() + output1.transform = Mock(return_value="test1") + + class output1State(ProcessState): + pass + + output2 = Output() + output2.transform = Mock(return_value="test2") + + class output2State(ProcessState): + pass + + output_mapping = OutputMapping({"key1": output1, "key2": output2}) + assert output_mapping({"key1": "output1", "key2": "output2"}) == {"key1": "test1", "key2": "test2"} + output1.transform.assert_called_once_with("output1") + output2.transform.assert_called_once_with("output2") + + with pytest.raises(ValueError, match="output must be a mapping"): + output_mapping("not a mapping") + + output1_state = output1State() + output2_state = output2State() + + output1.set_state(output1_state) + output2.set_state(output2_state) + + data_pipeline_state = DataPipelineState() + output_mapping.attach_data_pipeline_state(data_pipeline_state) + + assert output1._data_pipeline_state is data_pipeline_state + assert output2._data_pipeline_state is data_pipeline_state + + assert data_pipeline_state.get_state(output1State) is output1_state + assert data_pipeline_state.get_state(output2State) is output2_state diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index d317c0a9b5..c14cf35fcc 100644 --- a/tests/core/data/test_batch.py +++ b/tests/core/data/test_batch.py @@ -74,7 +74,7 @@ def test_postprocessor_str(): "\t(per_batch_transform): FuncModule(relu)\n" "\t(uncollate_fn): FuncModule(default_uncollate)\n" "\t(per_sample_transform): FuncModule(softmax)\n" - "\t(serializer): None" + "\t(output): None" ) diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index 68e48b546a..ce4df2e9fb 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -29,7 +29,8 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import _StageOrchestrator, DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource -from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer +from flash.core.data.io.output import Output +from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess from flash.core.data.properties import ProcessState from flash.core.data.states import PerBatchTransformOnDevice, ToTensorTransform from flash.core.model import Task @@ -73,12 +74,12 @@ def test_data_pipeline_str(): data_source=cast(DataSource, "data_source"), preprocess=cast(Preprocess, "preprocess"), postprocess=cast(Postprocess, "postprocess"), - serializer=cast(Serializer, "serializer"), + output=cast(Output, "output"), deserializer=cast(Deserializer, "deserializer"), ) expected = "data_source=data_source, deserializer=deserializer, " - expected += "preprocess=preprocess, postprocess=postprocess, serializer=serializer" + expected += "preprocess=preprocess, postprocess=postprocess, output=output" assert str(data_pipeline) == (f"DataPipeline({expected})") diff --git a/tests/core/data/test_process.py b/tests/core/data/test_process.py index 61ab591591..e792f83080 100644 --- a/tests/core/data/test_process.py +++ b/tests/core/data/test_process.py @@ -11,102 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from unittest.mock import Mock import pytest import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader -from flash import Task, Trainer -from flash.core.classification import Labels, LabelsState from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess +from flash.core.data.data_pipeline import DefaultPreprocess from flash.core.data.data_source import DefaultDataSources -from flash.core.data.process import Serializer, SerializerMapping -from flash.core.data.properties import ProcessState - - -def test_serializer(): - """Tests that ``Serializer`` can be enabled and disabled correctly.""" - - my_serializer = Serializer() - - assert my_serializer.serialize("test") == "test" - my_serializer.serialize = Mock() - - my_serializer.disable() - assert my_serializer("test") == "test" - my_serializer.serialize.assert_not_called() - - my_serializer.enable() - my_serializer("test") - my_serializer.serialize.assert_called_once() - - -def test_serializer_mapping(): - """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers. - - Also checks that state is retrieved / loaded correctly. - """ - - serializer1 = Serializer() - serializer1.serialize = Mock(return_value="test1") - - class Serializer1State(ProcessState): - pass - - serializer2 = Serializer() - serializer2.serialize = Mock(return_value="test2") - - class Serializer2State(ProcessState): - pass - - serializer_mapping = SerializerMapping({"key1": serializer1, "key2": serializer2}) - assert serializer_mapping({"key1": "serializer1", "key2": "serializer2"}) == {"key1": "test1", "key2": "test2"} - serializer1.serialize.assert_called_once_with("serializer1") - serializer2.serialize.assert_called_once_with("serializer2") - - with pytest.raises(ValueError, match="output must be a mapping"): - serializer_mapping("not a mapping") - - serializer1_state = Serializer1State() - serializer2_state = Serializer2State() - - serializer1.set_state(serializer1_state) - serializer2.set_state(serializer2_state) - - data_pipeline_state = DataPipelineState() - serializer_mapping.attach_data_pipeline_state(data_pipeline_state) - - assert serializer1._data_pipeline_state is data_pipeline_state - assert serializer2._data_pipeline_state is data_pipeline_state - - assert data_pipeline_state.get_state(Serializer1State) is serializer1_state - assert data_pipeline_state.get_state(Serializer2State) is serializer2_state - - -def test_saving_with_serializers(tmpdir): - checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") - - class CustomModel(Task): - def __init__(self): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - - serializer = Labels(["a", "b"]) - model = CustomModel() - trainer = Trainer(fast_dev_run=True) - data_pipeline = DataPipeline(preprocess=DefaultPreprocess(), serializer=serializer) - data_pipeline.initialize() - model.data_pipeline = data_pipeline - assert isinstance(model.preprocess, DefaultPreprocess) - dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) - trainer.fit(model, train_dataloader=dummy_data) - trainer.save_checkpoint(checkpoint_file) - model = CustomModel.load_from_checkpoint(checkpoint_file) - assert isinstance(model._data_pipeline_state, DataPipelineState) - assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) class CustomPreprocess(DefaultPreprocess): diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py index 6cfa7a2c50..a7c6f6f38a 100644 --- a/tests/core/test_classification.py +++ b/tests/core/test_classification.py @@ -19,53 +19,53 @@ from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE -def test_classification_serializers(): +def test_classification_outputs(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes labels = ["class_1", "class_2", "class_3"] - assert torch.allclose(torch.tensor(Logits().serialize(example_output)), example_output) - assert torch.allclose(torch.tensor(Probabilities().serialize(example_output)), torch.softmax(example_output, -1)) - assert Classes().serialize(example_output) == 2 - assert Labels(labels).serialize(example_output) == "class_3" + assert torch.allclose(torch.tensor(Logits().transform(example_output)), example_output) + assert torch.allclose(torch.tensor(Probabilities().transform(example_output)), torch.softmax(example_output, -1)) + assert Classes().transform(example_output) == 2 + assert Labels(labels).transform(example_output) == "class_3" -def test_classification_serializers_multi_label(): +def test_classification_outputs_multi_label(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes labels = ["class_1", "class_2", "class_3"] - assert torch.allclose(torch.tensor(Logits(multi_label=True).serialize(example_output)), example_output) + assert torch.allclose(torch.tensor(Logits(multi_label=True).transform(example_output)), example_output) assert torch.allclose( - torch.tensor(Probabilities(multi_label=True).serialize(example_output)), + torch.tensor(Probabilities(multi_label=True).transform(example_output)), torch.sigmoid(example_output), ) - assert Classes(multi_label=True).serialize(example_output) == [1, 2] - assert Labels(labels, multi_label=True).serialize(example_output) == ["class_2", "class_3"] + assert Classes(multi_label=True).transform(example_output) == [1, 2] + assert Labels(labels, multi_label=True).transform(example_output) == ["class_2", "class_3"] @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") -def test_classification_serializers_fiftyone(): +def test_classification_outputs_fiftyone(): logits = torch.tensor([-0.1, 0.2, 0.3]) example_output = {DefaultDataKeys.PREDS: logits, DefaultDataKeys.METADATA: {"filepath": "something"}} # 3 classes labels = ["class_1", "class_2", "class_3"] - predictions = FiftyOneLabels(return_filepath=True).serialize(example_output) + predictions = FiftyOneLabels(return_filepath=True).transform(example_output) assert predictions["predictions"].label == "2" assert predictions["filepath"] == "something" - predictions = FiftyOneLabels(labels, return_filepath=True).serialize(example_output) + predictions = FiftyOneLabels(labels, return_filepath=True).transform(example_output) assert predictions["predictions"].label == "class_3" assert predictions["filepath"] == "something" - predictions = FiftyOneLabels(store_logits=True).serialize(example_output) + predictions = FiftyOneLabels(store_logits=True).transform(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) assert torch.allclose(torch.tensor(predictions.confidence), torch.softmax(logits, -1)[-1]) assert predictions.label == "2" - predictions = FiftyOneLabels(labels, store_logits=True).serialize(example_output) + predictions = FiftyOneLabels(labels, store_logits=True).transform(example_output) assert predictions.label == "class_3" - predictions = FiftyOneLabels(store_logits=True, multi_label=True).serialize(example_output) + predictions = FiftyOneLabels(store_logits=True, multi_label=True).transform(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) assert [c.label for c in predictions.classifications] == ["1", "2"] - predictions = FiftyOneLabels(labels, multi_label=True).serialize(example_output) + predictions = FiftyOneLabels(labels, multi_label=True).transform(example_output) assert [c.label for c in predictions.classifications] == ["class_2", "class_3"] diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 725974c595..1892e8bcbe 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -92,7 +92,7 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s ) model = ImageClassifier( - backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities() + backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, output=Probabilities() ) trainer = flash.Trainer(max_epochs=3) active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3) @@ -144,7 +144,7 @@ def test_no_validation_loop(simple_datamodule): ) model = ImageClassifier( - backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities() + backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, output=Probabilities() ) trainer = flash.Trainer(max_epochs=3) active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3) diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 7dc49a3abc..96da2f4a11 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -104,7 +104,7 @@ def test_multilabel(tmpdir): num_classes = 4 ds = DummyMultiLabelDataset(num_classes) - model = ImageClassifier(num_classes, multi_label=True, serializer=Probabilities(multi_label=True)) + model = ImageClassifier(num_classes, multi_label=True, output=Probabilities(multi_label=True)) train_dl = torch.utils.data.DataLoader(ds, batch_size=2) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") diff --git a/tests/image/detection/test_serialization.py b/tests/image/detection/test_output.py similarity index 89% rename from tests/image/detection/test_serialization.py rename to tests/image/detection/test_output.py index fcad6e5fe7..9023106c02 100644 --- a/tests/image/detection/test_serialization.py +++ b/tests/image/detection/test_output.py @@ -4,7 +4,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE -from flash.image.detection.serialization import FiftyOneDetectionLabels +from flash.image.detection.output import FiftyOneDetectionLabels @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -42,22 +42,22 @@ def test_serialize_fiftyone(): }, } - detections = serial.serialize(sample) + detections = serial.transform(sample) assert len(detections.detections) == 1 np.testing.assert_array_almost_equal(detections.detections[0].bounding_box, [0.2, 0.3, 0.2, 0.2]) assert detections.detections[0].confidence == 0.5 assert detections.detections[0].label == "0" - detections = filepath_serial.serialize(sample) + detections = filepath_serial.transform(sample) assert len(detections["predictions"].detections) == 1 np.testing.assert_array_almost_equal(detections["predictions"].detections[0].bounding_box, [0.2, 0.3, 0.2, 0.2]) assert detections["predictions"].detections[0].confidence == 0.5 assert detections["filepath"] == "something" - detections = threshold_serial.serialize(sample) + detections = threshold_serial.transform(sample) assert len(detections.detections) == 0 - detections = labels_serial.serialize(sample) + detections = labels_serial.transform(sample) assert len(detections.detections) == 1 np.testing.assert_array_almost_equal(detections.detections[0].bounding_box, [0.2, 0.3, 0.2, 0.2]) assert detections.detections[0].confidence == 0.5 diff --git a/tests/image/segmentation/test_serialization.py b/tests/image/segmentation/test_output.py similarity index 89% rename from tests/image/segmentation/test_serialization.py rename to tests/image/segmentation/test_output.py index 0e7477348a..ad06cd2bd8 100644 --- a/tests/image/segmentation/test_serialization.py +++ b/tests/image/segmentation/test_output.py @@ -16,7 +16,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE -from flash.image.segmentation.serialization import FiftyOneSegmentationLabels, SegmentationLabels +from flash.image.segmentation.output import FiftyOneSegmentationLabels, SegmentationLabels from tests.helpers.utils import _IMAGE_TESTING @@ -36,11 +36,11 @@ def test_exception(): with pytest.raises(Exception): sample = torch.zeros(1, 5, 2, 3) - serial.serialize(sample) + serial.transform(sample) with pytest.raises(Exception): sample = torch.zeros(2, 3) - serial.serialize(sample) + serial.transform(sample) @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") @staticmethod @@ -51,7 +51,7 @@ def test_serialize(): sample[1, 1, 2] = 1 # add peak in class 2 sample[3, 0, 1] = 1 # add peak in class 4 - classes = serial.serialize({DefaultDataKeys.PREDS: sample}) + classes = serial.transform({DefaultDataKeys.PREDS: sample}) assert torch.tensor(classes)[1, 2] == 1 assert torch.tensor(classes)[0, 1] == 3 @@ -71,11 +71,11 @@ def test_serialize_fiftyone(): DefaultDataKeys.METADATA: {"filepath": "something"}, } - segmentation = serial.serialize(sample) + segmentation = serial.transform(sample) assert segmentation.mask[1, 2] == 1 assert segmentation.mask[0, 1] == 3 - segmentation = filepath_serial.serialize(sample) + segmentation = filepath_serial.transform(sample) assert segmentation["predictions"].mask[1, 2] == 1 assert segmentation["predictions"].mask[0, 1] == 3 assert segmentation["filepath"] == "something"