This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor image inputs and update to new input object (#997)
Co-authored-by: thomas chaton <thomas@grid.ai> Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
- Loading branch information
1 parent
ce18c08
commit b8085f0
Showing
35 changed files
with
1,325 additions
and
452 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from dataclasses import dataclass | ||
from functools import lru_cache | ||
from typing import Any, List, Optional, Sequence | ||
|
||
from flash.core.data.io.input_base import Input | ||
from flash.core.data.properties import ProcessState | ||
from flash.core.data.utilities.classification import ( | ||
get_target_details, | ||
get_target_formatter, | ||
get_target_mode, | ||
TargetFormatter, | ||
) | ||
|
||
|
||
@dataclass(unsafe_hash=True, frozen=True) | ||
class ClassificationState(ProcessState): | ||
"""A :class:`~flash.core.data.properties.ProcessState` containing ``labels`` (a mapping from class index to | ||
label) and ``num_classes``.""" | ||
|
||
labels: Optional[Sequence[str]] | ||
num_classes: Optional[int] = None | ||
|
||
|
||
class ClassificationInput(Input): | ||
"""The ``ClassificationInput`` class provides utility methods for handling classification targets. | ||
:class:`~flash.core.data.io.input_base.Input` objects that extend ``ClassificationInput`` should do the following: | ||
* In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the | ||
targets and store metadata like ``labels`` and ``num_classes``. | ||
* In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our | ||
tasks. | ||
""" | ||
|
||
@property | ||
@lru_cache(maxsize=None) | ||
def target_formatter(self) -> TargetFormatter: | ||
"""Get the :class:`~flash.core.data.utiltiies.classification.TargetFormatter` to use when formatting | ||
targets. | ||
This property uses ``functools.lru_cache`` so that we only instantiate the formatter once. | ||
""" | ||
classification_state = self.get_state(ClassificationState) | ||
return get_target_formatter(self.target_mode, classification_state.labels, classification_state.num_classes) | ||
|
||
def load_target_metadata(self, targets: List[Any]) -> None: | ||
"""Determine the target format and store the ``labels`` and ``num_classes``. | ||
Args: | ||
targets: The list of targets. | ||
""" | ||
self.target_mode = get_target_mode(targets) | ||
self.multi_label = self.target_mode.multi_label | ||
if self.training: | ||
self.labels, self.num_classes = get_target_details(targets, self.target_mode) | ||
self.set_state(ClassificationState(self.labels, self.num_classes)) | ||
|
||
def format_target(self, target: Any) -> Any: | ||
"""Format a single target according to the previously computed target format and metadata. | ||
Args: | ||
target: The target to format. | ||
Returns: | ||
The formatted target. | ||
""" | ||
return self.target_formatter(target) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.