Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Clean Output (#939)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Nov 8, 2021
1 parent 4a242e6 commit 8de809e
Show file tree
Hide file tree
Showing 17 changed files with 28 additions and 113 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))

- Removed `Output.enable` and `Output.disable` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))


## [0.5.2] - 2021-11-05

Expand Down
1 change: 0 additions & 1 deletion docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ _________________________
:template: classtemplate.rst

~flash.core.data.io.output.Output
~flash.core.data.io.output.OutputMapping

flash.core.data.auto_dataset
____________________________
Expand Down
2 changes: 1 addition & 1 deletion flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SpeechRecognition(Task):
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
Expand Down
38 changes: 2 additions & 36 deletions flash/core/data/io/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping
from typing import Any

import torch

import flash
from flash.core.data.properties import Properties
from flash.core.data.utils import convert_to_modules

Expand All @@ -24,18 +23,6 @@ class Output(Properties):
"""An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which
is used to convert the model output into the desired output format when predicting."""

def __init__(self):
super().__init__()
self._is_enabled = True

def enable(self):
"""Enable output transformation."""
self._is_enabled = True

def disable(self):
"""Disable output transformation."""
self._is_enabled = False

@staticmethod
def transform(sample: Any) -> Any:
"""Convert the given sample into the desired output format.
Expand All @@ -49,28 +36,7 @@ def transform(sample: Any) -> Any:
return sample

def __call__(self, sample: Any) -> Any:
if self._is_enabled:
return self.transform(sample)
return sample


class OutputMapping(Output):
"""If the model output is a dictionary, then the :class:`.OutputMapping` enables each entry in the dictionary
to be passed to it's own :class:`.Output`."""

def __init__(self, outputs: Mapping[str, Output]):
super().__init__()

self._outputs = outputs

def transform(self, sample: Any) -> Any:
if isinstance(sample, Mapping):
return {key: output.transform(sample[key]) for key, output in self._outputs.items()}
raise ValueError("The model output must be a mapping when using an OutputMapping.")

def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"):
for output in self._outputs.values():
output.attach_data_pipeline_state(data_pipeline_state)
return self.transform(sample)


class _OutputProcessor(torch.nn.Module):
Expand Down
11 changes: 4 additions & 7 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.data_source import DataSource
from flash.core.data.io.output import Output, OutputMapping
from flash.core.data.io.output import Output
from flash.core.data.process import Deserializer, DeserializerMapping, Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY
Expand Down Expand Up @@ -319,8 +319,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check
deserialize the input
preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task.
postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task.
output: Either a single :class:`~flash.core.data.io.output.Output` or a mapping of these to
serialize the output e.g. convert the model output into the desired output format when predicting.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY
Expand Down Expand Up @@ -630,9 +629,7 @@ def output(self) -> Optional[Output]:

@torch.jit.unused
@output.setter
def output(self, output: Union[Output, Mapping[str, Output]]):
if isinstance(output, Mapping):
output = OutputMapping(output)
def output(self, output: Output):
self._output = output

@torch.jit.unused
Expand Down Expand Up @@ -662,7 +659,7 @@ def serializer(self) -> Optional[Output]:
"It will be removed in v%(remove_in)s.",
stream=functools.partial(warn, category=FutureWarning),
)
def serializer(self, serializer: Union[Output, Mapping[str, Output]]):
def serializer(self, serializer: Output):
self.output = serializer

def build_data_pipeline(
Expand Down
2 changes: 1 addition & 1 deletion flash/core/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]]
PREPROCESS_TYPE = Optional[Preprocess]
POSTPROCESS_TYPE = Optional[Postprocess]
OUTPUT_TYPE = Optional[Union[Output, Mapping[str, Output]]]
OUTPUT_TYPE = Optional[Output]
3 changes: 1 addition & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def fn_resnet(pretrained: bool = True):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
multi_label: Whether the targets are multi-label or not.
output: A instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such
to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
training_strategy: string indicating the training strategy. Adjust if you want to use `learn2learn`
for doing meta-learning research
training_strategy_kwargs: Additional kwargs for setting the training strategy
Expand Down
3 changes: 1 addition & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ class ObjectDetector(AdapterTask):
pretrained: Whether the model from torchvision should be loaded with it's pretrained weights.
Has no effect for custom models.
learning_rate: The learning rate to use for training
output: A instance of :class:`~flash.core.data.io.output.Output` or a mapping consisting of such
to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
kwargs: additional kwargs nessesary for initializing the backbone task
"""

Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class SemanticSegmentation(ClassificationTask):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.IOU`.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
postprocess: :class:`~flash.core.data.process.Postprocess` use for post processing samples.
"""

Expand Down
2 changes: 1 addition & 1 deletion flash/pointcloud/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PointCloudObjectDetector(Task):
by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
learning_rate: The learning rate for the optimizer.
multi_label: If ``True``, this will be treated as a multi-label classification problem.
output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
lambda_loss_cls: The value to scale the loss classification.
lambda_loss_bbox: The value to scale the bounding boxes loss.
lambda_loss_dir: The value to scale the bounding boxes direction loss.
Expand Down
2 changes: 1 addition & 1 deletion flash/pointcloud/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class PointCloudSegmentation(ClassificationTask):
by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
learning_rate: The learning rate for the optimizer.
multi_label: If ``True``, this will be treated as a multi-label classification problem.
output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TabularClassifier(ClassificationTask):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
`pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
"""
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/regression/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TabularRegressor(RegressionTask):
`metric(preds,target)` and return a single scalar tensor.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
`pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
"""
Expand Down
2 changes: 1 addition & 1 deletion flash/template/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TemplateSKLearnClassifier(ClassificationTask):
by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
learning_rate: The learning rate for the optimizer.
multi_label: If ``True``, this will be treated as a multi-label classification problem.
output: The :class:`~flash.core.data.io.output.Output` to use for prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = TEMPLATE_BACKBONES
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TextClassifier(ClassificationTask):
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to `1e-3`
multi_label: Whether the targets are multi-label or not.
output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""

Expand Down
3 changes: 1 addition & 2 deletions flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ class VideoClassifier(ClassificationTask):
head: either a `nn.Module` or a callable function that converts the features extrated from the backbone
into class log probabilities (assuming default loss function). If `None`, will default to using
a single linear layer.
output: A instance of :class:`~flash.core.data.io.output.Output` that determines how the output
should be serialized e.g. convert the model output into the desired output format when predicting.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
"""

backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES
Expand Down
58 changes: 4 additions & 54 deletions tests/core/data/io/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,25 @@
import os
from unittest.mock import Mock

import pytest
import torch
from torch.utils.data import DataLoader

from flash.core.classification import Labels
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.data_source import LabelsState
from flash.core.data.io.output import Output, OutputMapping
from flash.core.data.io.output import Output
from flash.core.data.process import DefaultPreprocess
from flash.core.data.properties import ProcessState
from flash.core.model import Task
from flash.core.trainer import Trainer


def test_output_enable_disable():
"""Tests that ``Output`` can be enabled and disabled correctly."""

def test_output():
"""Tests basic ``Output`` methods."""
my_output = Output()

assert my_output.transform("test") == "test"
my_output.transform = Mock()

my_output.disable()
assert my_output("test") == "test"
my_output.transform.assert_not_called()

my_output.enable()
my_output.transform = Mock()
my_output("test")
my_output.transform.assert_called_once()

Expand All @@ -65,45 +57,3 @@ def __init__(self):
model = CustomModel.load_from_checkpoint(checkpoint_file)
assert isinstance(model._data_pipeline_state, DataPipelineState)
assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"])


def test_output_mapping():
"""Tests that ``OutputMapping`` correctly passes its inputs to the underlying outputs.
Also checks that state is retrieved / loaded correctly.
"""

output1 = Output()
output1.transform = Mock(return_value="test1")

class output1State(ProcessState):
pass

output2 = Output()
output2.transform = Mock(return_value="test2")

class output2State(ProcessState):
pass

output_mapping = OutputMapping({"key1": output1, "key2": output2})
assert output_mapping({"key1": "output1", "key2": "output2"}) == {"key1": "test1", "key2": "test2"}
output1.transform.assert_called_once_with("output1")
output2.transform.assert_called_once_with("output2")

with pytest.raises(ValueError, match="output must be a mapping"):
output_mapping("not a mapping")

output1_state = output1State()
output2_state = output2State()

output1.set_state(output1_state)
output2.set_state(output2_state)

data_pipeline_state = DataPipelineState()
output_mapping.attach_data_pipeline_state(data_pipeline_state)

assert output1._data_pipeline_state is data_pipeline_state
assert output2._data_pipeline_state is data_pipeline_state

assert data_pipeline_state.get_state(output1State) is output1_state
assert data_pipeline_state.get_state(output2State) is output2_state

0 comments on commit 8de809e

Please sign in to comment.