Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add support for Iterable auto dataset + resolve a bug for Preprocess …
Browse files Browse the repository at this point in the history
…Transforms. (#227)

* update

* update

* Update flash/vision/video/classification/data.py

* update

* Update flash/vision/video/classification/model.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* update

* update

* typo

* update

* update

* resolve some internal bugs

* update on comments

* move files

* update

* update

* update

* filter for 3.6

* update on comments

* update

* update

* update

* clean auto dataset

* typo

* update

* update on comments

* add doc

* remove backbone section

* update

* update

* update

* update

* map to None

* update

* update

* update on comments

* update script

* update on comments

* drop video integration

* resolve bug

* remove video docs

* remove pytorchvideo

* update

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
  • Loading branch information
tchaton and kaushikb11 committed Apr 19, 2021
1 parent 42cc20a commit 781fa98
Show file tree
Hide file tree
Showing 14 changed files with 368 additions and 106 deletions.
2 changes: 1 addition & 1 deletion docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Image Classification
********
The task
********
The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that desecribes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc. For example, we can train the image classifier task on images of ants and it will learn to predict the probability that an image contains an ant.
The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc. For example, we can train the image classifier task on images of ants and it will learn to predict the probability that an image contains an ant.

------

Expand Down
4 changes: 4 additions & 0 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any

import torch
import torch.nn.functional as F

from flash.core.model import Task
from flash.data.process import Postprocess
Expand All @@ -29,3 +30,6 @@ class ClassificationTask(Task):

def __init__(self, *args, **kwargs) -> None:
    # Install the classification-specific postprocess as the task's default;
    # every other argument is forwarded unchanged to ``Task``.
    super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
    """Convert raw model outputs (logits) to probabilities over the last dim,
    the format the task's metrics expect."""
    return torch.softmax(x, dim=-1)
19 changes: 15 additions & 4 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
output = {"y_hat": y_hat}
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(y_hat)
for name, metric in self.metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
Expand All @@ -111,6 +112,9 @@ def step(self, batch: Any, batch_idx: int) -> Any:
output["y"] = y
return output

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
    """Hook mapping model outputs to the form the metrics expect.

    Identity by default; subclasses (e.g. ``ClassificationTask``) override it
    to apply e.g. a softmax before metrics are computed in ``step``.
    """
    return x

def forward(self, x: Any) -> Any:
return self.model(x)

Expand Down Expand Up @@ -172,10 +176,10 @@ def configure_finetune_callback(self) -> List[Callback]:

@staticmethod
def _resolve(
old_preprocess: Optional[Preprocess],
old_postprocess: Optional[Postprocess],
new_preprocess: Optional[Preprocess],
new_postprocess: Optional[Postprocess],
old_preprocess: Optional[Preprocess],
old_postprocess: Optional[Postprocess],
new_preprocess: Optional[Preprocess],
new_postprocess: Optional[Postprocess],
) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
"""Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not
None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise.
Expand Down Expand Up @@ -308,3 +312,10 @@ def available_backbones(cls) -> List[str]:
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_models(cls) -> List[str]:
    """Names of all models registered on this class's ``models`` registry.

    Returns an empty list when the class defines no ``models`` registry,
    mirroring ``available_backbones``.
    """
    registry = getattr(cls, "models", None)
    return [] if registry is None else registry.available_keys()
49 changes: 42 additions & 7 deletions flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from inspect import signature
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING
from typing import Any, Callable, Iterable, Iterator, Optional, TYPE_CHECKING

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warning_utils import rank_zero_warn
from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterableDataset

from flash.data.callback import ControlFlow
from flash.data.process import Preprocess
Expand All @@ -27,13 +27,13 @@
from flash.data.data_pipeline import DataPipeline


class AutoDataset(Dataset):
class BaseAutoDataset:

DATASET_KEY = "dataset"
"""
This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions.
``load_data`` will be called within the ``__init__`` function of the AutoDataset if ``running_stage``
is provided and ``load_sample`` within ``__getitem__`` function.
is provided and ``load_sample`` within ``__getitem__``.
"""

def __init__(
Expand Down Expand Up @@ -122,10 +122,19 @@ def _setup(self, stage: Optional[RunningStage]) -> None:
"The load_data function of the Autogenerated Dataset changed. "
"This is not expected! Preloading Data again to ensure compatibility. This may take some time."
)
with self._load_data_context:
self.preprocessed_data = self._call_load_data(self.data)
self.setup()
self._load_data_called = True

def setup(self):
    # Subclass hook: populate the backing data produced by ``load_data``
    # (see ``AutoDataset.setup`` and ``IterableAutoDataset.setup``).
    raise NotImplementedError


class AutoDataset(BaseAutoDataset, Dataset):

def setup(self):
    # Eagerly run the resolved ``load_data`` over ``self.data``; the result is
    # stored as ``preprocessed_data``, which backs ``__getitem__``/``__len__``.
    with self._load_data_context:
        self.preprocessed_data = self._call_load_data(self.data)

def __getitem__(self, index: int) -> Any:
if not self.load_sample and not self.load_data:
raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
Expand All @@ -141,3 +150,29 @@ def __len__(self) -> int:
if not self.load_sample and not self.load_data:
raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.")
return len(self.preprocessed_data)


class IterableAutoDataset(BaseAutoDataset, IterableDataset):
    """Iterable flavour of the auto dataset.

    ``load_data`` is expected to return an iterable; samples are pulled lazily
    through ``__iter__``/``__next__`` and, when available, mapped through
    ``load_sample`` (with the control-flow callback notified per sample).
    """

    def setup(self):
        # Resolve the underlying iterable once; the iterator itself is created
        # lazily in ``__iter__``.
        with self._load_data_context:
            self.dataset = self._call_load_data(self.data)
        self.dataset_iter = None

    def __iter__(self):
        self.dataset_iter = iter(self.dataset)
        return self

    def __next__(self) -> Any:
        if not self.load_sample and not self.load_data:
            # Fix: this dataset is consumed via ``__next__``; the previous
            # message wrongly referenced ``__getitem__``.
            raise RuntimeError("`__next__` for `load_sample` and `load_data` could not be inferred.")

        data = next(self.dataset_iter)

        # Single exit point: the original duplicated ``return data``.
        if self.load_sample:
            with self._load_sample_context:
                data = self._call_load_sample(data)
            if self.control_flow_callback:
                self.control_flow_callback.on_load_sample(data, self.running_stage)
        return data
73 changes: 57 additions & 16 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import IterableDataset, Subset

from flash.data.auto_dataset import AutoDataset
from flash.data.auto_dataset import BaseAutoDataset, IterableAutoDataset
from flash.data.base_viz import BaseVisualization
from flash.data.callback import BaseDataFetcher
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
Expand Down Expand Up @@ -212,15 +212,16 @@ def set_running_stages(self):
self.set_dataset_attribute(self._predict_ds, 'running_stage', RunningStage.PREDICTING)

def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]:
    """Pick the pipeline's worker preprocessor as collate for auto datasets.

    Plain datasets get ``None`` (the ``DataLoader`` default collate).
    """
    if not isinstance(dataset, BaseAutoDataset):
        return None
    return self.data_pipeline.worker_preprocessor(running_stage)

def _train_dataloader(self) -> DataLoader:
train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=True,
shuffle=shuffle,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
Expand Down Expand Up @@ -249,10 +250,13 @@ def _test_dataloader(self) -> DataLoader:

def _predict_dataloader(self) -> DataLoader:
predict_ds: Dataset = self._predict_ds() if isinstance(self._predict_ds, Callable) else self._predict_ds
if isinstance(predict_ds, IterableAutoDataset):
batch_size = self.batch_size
else:
batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)
return DataLoader(
predict_ds,
batch_size=min(self.batch_size,
len(predict_ds) if len(predict_ds) > 0 else 1),
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
Expand All @@ -263,6 +267,13 @@ def generate_auto_dataset(self, *args, **kwargs):
return None
return self.data_pipeline._generate_auto_dataset(*args, **kwargs)

@property
def num_classes(self) -> Optional[int]:
    """First ``num_classes`` found on the train/val/test datasets, else ``None``.

    Uses an explicit ``is not None`` check so a legitimate value of ``0`` is
    returned instead of being skipped as falsy, which the previous
    ``or``-chain would do.
    """
    for dataset in (self.train_dataset, self.val_dataset, self.test_dataset):
        num_classes = getattr(dataset, "num_classes", None)
        if num_classes is not None:
            return num_classes
    return None

@property
def preprocess(self) -> Preprocess:
    """The attached preprocess object, or a fresh default instance when unset."""
    if self._preprocess:
        return self._preprocess
    return self.preprocess_cls()
Expand Down Expand Up @@ -292,9 +303,10 @@ def autogenerate_dataset(
whole_data_load_fn: Optional[Callable] = None,
per_sample_load_fn: Optional[Callable] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> AutoDataset:
use_iterable_auto_dataset: bool = False,
) -> BaseAutoDataset:
"""
This function is used to generate an ``AutoDataset`` from a ``DataPipeline`` if provided
This function is used to generate an ``BaseAutoDataset`` from a ``DataPipeline`` if provided
or from the provided ``whole_data_load_fn``, ``per_sample_load_fn`` functions directly
"""

Expand All @@ -309,7 +321,11 @@ def autogenerate_dataset(
cls.preprocess_cls,
DataPipeline._resolve_function_hierarchy('load_sample', cls.preprocess_cls, running_stage, Preprocess)
)
return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)
if use_iterable_auto_dataset:
return IterableAutoDataset(
data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage
)
return BaseAutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)

@staticmethod
def train_val_test_split(
Expand Down Expand Up @@ -379,15 +395,27 @@ def _generate_dataset_if_possible(
running_stage: RunningStage,
whole_data_load_fn: Optional[Callable] = None,
per_sample_load_fn: Optional[Callable] = None,
data_pipeline: Optional[DataPipeline] = None
) -> Optional[AutoDataset]:
data_pipeline: Optional[DataPipeline] = None,
use_iterable_auto_dataset: bool = False,
) -> Optional[BaseAutoDataset]:
if data is None:
return

if data_pipeline:
return data_pipeline._generate_auto_dataset(data, running_stage=running_stage)
return data_pipeline._generate_auto_dataset(
data,
running_stage=running_stage,
use_iterable_auto_dataset=use_iterable_auto_dataset,
)

return cls.autogenerate_dataset(data, running_stage, whole_data_load_fn, per_sample_load_fn, data_pipeline)
return cls.autogenerate_dataset(
data,
running_stage,
whole_data_load_fn,
per_sample_load_fn,
data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset,
)

@classmethod
def from_load_data_inputs(
Expand All @@ -398,6 +426,7 @@ def from_load_data_inputs(
predict_load_data_input: Optional[Any] = None,
preprocess: Optional[Preprocess] = None,
postprocess: Optional[Postprocess] = None,
use_iterable_auto_dataset: bool = False,
**kwargs,
) -> 'DataModule':
"""
Expand Down Expand Up @@ -429,16 +458,28 @@ def from_load_data_inputs(
data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)

train_dataset = cls._generate_dataset_if_possible(
train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
train_load_data_input,
running_stage=RunningStage.TRAINING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset,
)
val_dataset = cls._generate_dataset_if_possible(
val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
val_load_data_input,
running_stage=RunningStage.VALIDATING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset,
)
test_dataset = cls._generate_dataset_if_possible(
test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline
test_load_data_input,
running_stage=RunningStage.TESTING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset,
)
predict_dataset = cls._generate_dataset_if_possible(
predict_load_data_input, running_stage=RunningStage.PREDICTING, data_pipeline=data_pipeline
predict_load_data_input,
running_stage=RunningStage.PREDICTING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset,
)
datamodule = cls(
train_dataset=train_dataset,
Expand Down
Loading

0 comments on commit 781fa98

Please sign in to comment.