diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py
index 1f833ed1f7..631020faec 100644
--- a/flash/video/classification/data.py
+++ b/flash/video/classification/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Type, Union
 
 import pandas as pd
 import torch
@@ -41,6 +41,8 @@
     VideoClassificationFilesInput,
     VideoClassificationFoldersInput,
     VideoClassificationPathsPredictInput,
+    VideoClassificationTensorsInput,
+    VideoClassificationTensorsPredictInput,
 )
 from flash.video.classification.input_transform import VideoClassificationInputTransform
 
@@ -63,6 +65,7 @@
     "VideoClassificationData.from_folders",
     "VideoClassificationData.from_data_frame",
     "VideoClassificationData.from_csv",
+    "VideoClassificationData.from_tensors",
 ]
 if not _VIDEO_EXTRAS_TESTING:
     __doctest_skip__ += ["VideoClassificationData.from_fiftyone"]
@@ -395,7 +398,6 @@ def from_data_frame(
         predict_data_frame: Optional[pd.DataFrame] = None,
         predict_videos_root: Optional[str] = None,
         predict_resolver: Optional[Callable[[str, str], str]] = None,
-        target_formatter: Optional[TargetFormatter] = None,
         clip_sampler: Union[str, "ClipSampler"] = "random",
         clip_duration: float = 2,
         clip_sampler_kwargs: Dict[str, Any] = None,
@@ -404,6 +406,7 @@ def from_data_frame(
         decoder: str = "pyav",
         input_cls: Type[Input] = VideoClassificationDataFrameInput,
         predict_input_cls: Type[Input] = VideoClassificationDataFramePredictInput,
+        target_formatter: Optional[TargetFormatter] = None,
         transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs: Any,
@@ -566,6 +569,122 @@ def from_data_frame(
             **data_module_kwargs,
         )
 
+    @classmethod
+    def from_tensors(
+        cls,
+        train_data: Optional[Union[Collection[torch.Tensor], torch.Tensor]] = None,
+        train_targets: Optional[Collection[Any]] = None,
+        val_data: Optional[Union[Collection[torch.Tensor], torch.Tensor]] = None,
+        val_targets: Optional[Sequence[Any]] = None,
+        test_data: Optional[Collection[torch.Tensor]] = None,
+        test_targets: Optional[Sequence[Any]] = None,
+        predict_data: Optional[Union[Collection[torch.Tensor], torch.Tensor]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+        video_sampler: Type[Sampler] = torch.utils.data.SequentialSampler,
+        input_cls: Type[Input] = VideoClassificationTensorsInput,
+        predict_input_cls: Type[Input] = VideoClassificationTensorsPredictInput,
+        transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any,
+    ) -> "VideoClassificationData":
+        """Load the :class:`~flash.video.classification.data.VideoClassificationData` from PyTorch tensors (or lists
+        of tensors) representing input video frames and their corresponding targets.
+
+        Each input can be a single video tensor in ``(C, T, H, W)`` format or a stack of videos in
+        ``(N, C, T, H, W)`` format. The targets can be in any of our
+        :ref:`supported classification target formats <formatting_classification_targets>`.
+
+        To learn how to customize the transforms applied for each stage, read our
+        :ref:`customizing transforms guide <customizing_transforms>`.
+
+        Args:
+            train_data: The torch tensor or list of tensors to use when training.
+            train_targets: The list of targets to use when training.
+            val_data: The torch tensor or list of tensors to use when validating.
+            val_targets: The list of targets to use when validating.
+            test_data: The torch tensor or list of tensors to use when testing.
+            test_targets: The list of targets to use when testing.
+            predict_data: The torch tensor or list of tensors to use when predicting.
+            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
+                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
+            video_sampler: Sampler for the internal video container. This defines the order tensors are used and,
+                if necessary, the distributed split.
+            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
+            predict_input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the prediction data.
+            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
+            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
+            data_module_kwargs: Additional keyword arguments to provide to the
+                :class:`~flash.core.data.data_module.DataModule` constructor.
+
+        Returns:
+            The constructed :class:`~flash.video.classification.data.VideoClassificationData`.
+
+        Examples
+        ________
+
+        .. doctest::
+
+            >>> import torch
+            >>> from flash import Trainer
+            >>> from flash.video import VideoClassifier, VideoClassificationData
+            >>> frame = torch.randint(low=0, high=255, size=(3, 5, 10, 10), dtype=torch.uint8, device="cpu")
+            >>> datamodule = VideoClassificationData.from_tensors(
+            ...     train_data=[frame, frame, frame],
+            ...     train_targets=["fruit", "vegetable", "fruit"],
+            ...     val_data=[frame, frame],
+            ...     val_targets=["vegetable", "fruit"],
+            ...     predict_data=[frame],
+            ...     batch_size=1,
+            ... )
+            >>> datamodule.num_classes
+            2
+            >>> datamodule.labels
+            ['fruit', 'vegetable']
+            >>> model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes)
+            >>> trainer = Trainer(fast_dev_run=True)
+            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Training...
+            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Predicting...
+
+        .. testcleanup::
+
+            >>> del frame
+        """
+
+        train_input = input_cls(
+            RunningStage.TRAINING,
+            train_data,
+            train_targets,
+            video_sampler=video_sampler,
+            target_formatter=target_formatter,
+        )
+        target_formatter = getattr(train_input, "target_formatter", None)
+
+        return cls(
+            train_input,
+            input_cls(
+                RunningStage.VALIDATING,
+                val_data,
+                val_targets,
+                video_sampler=video_sampler,
+                target_formatter=target_formatter,
+            ),
+            input_cls(
+                RunningStage.TESTING,
+                test_data,
+                test_targets,
+                video_sampler=video_sampler,
+                target_formatter=target_formatter,
+            ),
+            predict_input_cls(RunningStage.PREDICTING, predict_data),
+            transform=transform,
+            transform_kwargs=transform_kwargs,
+            **data_module_kwargs,
+        )
+
     @classmethod
     def from_csv(
         cls,
diff --git a/flash/video/classification/input.py b/flash/video/classification/input.py
index b149587e2e..96b3d3ff45 100644
--- a/flash/video/classification/input.py
+++ b/flash/video/classification/input.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
import os -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union import pandas as pd import torch @@ -21,7 +21,7 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input, IterableInput -from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter +from flash.core.data.utilities.classification import _is_list_like, MultiBinaryTargetFormatter, TargetFormatter from flash.core.data.utilities.data_frame import resolve_files, resolve_targets from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import list_valid_files, make_dataset, PATH_TYPE @@ -40,8 +40,17 @@ from pytorchvideo.data.encoded_video import EncodedVideo from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths + + from flash.video.classification.utils import LabeledVideoTensorDataset + else: - ClipSampler, LabeledVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None + ClipSampler, LabeledVideoDataset, LabeledVideoTensorDataset, EncodedVideo, ApplyTransformToKey = ( + None, + None, + None, + None, + None, + ) def _make_clip_sampler( @@ -87,6 +96,43 @@ def load_sample(self, sample): return sample +class VideoClassificationTensorsBaseInput(IterableInput, ClassificationInputMixin): + def load_data( + self, + inputs: Optional[Union[Collection[torch.Tensor], torch.Tensor]], + targets: Union[List[Any], Any], + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + target_formatter: Optional[TargetFormatter] = None, + ) -> "LabeledVideoTensorDataset": + if isinstance(inputs, torch.Tensor): + # In case of (number of videos x CTHW) format + if inputs.ndim == 5: + inputs = list(inputs) + elif inputs.ndim == 4: + inputs = [inputs] + else: + raise ValueError( + f"Got dimension of the input tensor: {inputs.ndim}" + " for stack of tensors - dimension should be 5 or for a single tensor, dimension should be 4.", + ) + elif not _is_list_like(inputs): + raise TypeError(f"Expected either a list/tuple of torch.Tensor or torch.Tensor, but got: {type(inputs)}.") + + # Note: We take whatever is the shortest out of inputs and targets + dataset = LabeledVideoTensorDataset(list(zip(inputs, targets)), video_sampler=video_sampler) + if not self.predicting: + self.load_target_metadata( + [sample[1] for sample in dataset._labeled_videos], target_formatter=target_formatter + ) + return dataset + + def load_sample(self, sample): + sample["label"] = self.format_target(sample["label"]) + sample[DataKeys.INPUT] = sample.pop("video") + sample[DataKeys.TARGET] = sample.pop("label") + return sample + + class VideoClassificationFoldersInput(VideoClassificationInput): def load_data( self, @@ -178,6 +224,34 @@ def load_data( return result +class VideoClassificationTensorsInput(VideoClassificationTensorsBaseInput): + labels: list + + def load_data( + self, + tensors: Any, + targets: Optional[List[Any]] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + target_formatter: Optional[TargetFormatter] = None, + ) -> "LabeledVideoTensorDataset": + result = super().load_data( + tensors, + targets, + video_sampler=video_sampler, + target_formatter=target_formatter, + ) + + # If we had binary multi-class targets then we also know the labels (column names) + if ( + self.training + and 
isinstance(self.target_formatter, MultiBinaryTargetFormatter)
+            and isinstance(targets, List)
+        ):
+            self.labels = targets
+
+        return result
+
+
 class VideoClassificationCSVInput(VideoClassificationDataFrameInput):
     def load_data(
         self,
@@ -316,6 +390,30 @@ def predict_load_data(
         )
 
 
+class VideoClassificationTensorsPredictInput(Input):
+    def predict_load_data(self, data: Union[torch.Tensor, List[Any], Any]):
+        if _is_list_like(data):
+            return data
+        else:
+            if not isinstance(data, torch.Tensor):
+                raise TypeError(f"Expected either a list/tuple of torch.Tensor or torch.Tensor, but got: {type(data)}.")
+            if data.ndim == 5:
+                return list(data)
+            elif data.ndim == 4:
+                return [data]
+            else:
+                raise ValueError(
+                    f"Got dimension of the input tensor: {data.ndim},"
+                    " for stack of tensors - dimension should be 5 or for a single tensor, dimension should be 4."
+                )
+
+    def predict_load_sample(self, sample: torch.Tensor) -> Dict[str, Any]:
+        return {
+            DataKeys.INPUT: sample,
+            "video_index": 0,
+        }
+
+
 class VideoClassificationCSVPredictInput(VideoClassificationDataFramePredictInput):
     def predict_load_data(
         self,
diff --git a/flash/video/classification/utils.py b/flash/video/classification/utils.py
new file mode 100644
index 0000000000..5d51ca216e
--- /dev/null
+++ b/flash/video/classification/utils.py
@@ -0,0 +1,89 @@
+from typing import Any, List, Tuple, Type
+
+import torch
+
+from flash.core.utilities.imports import _VIDEO_AVAILABLE
+
+if _VIDEO_AVAILABLE:
+    from pytorchvideo.data.utils import MultiProcessSampler
+else:
+    MultiProcessSampler = None
+
+
+class LabeledVideoTensorDataset(torch.utils.data.IterableDataset):
+    """LabeledVideoTensorDataset handles direct tensor input data."""
+
+    def __init__(
+        self,
+        labeled_video_tensors: List[Tuple[torch.Tensor, Any]],
+        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+    ) -> None:
+        self._labeled_videos = labeled_video_tensors
+
+        # If a RandomSampler is used we need to pass in a custom random generator that
+        # ensures all PyTorch multiprocess workers have the same random seed.
+        self._video_random_generator = None
+        if video_sampler == torch.utils.data.RandomSampler:
+            self._video_random_generator = torch.Generator()
+            self._video_sampler = video_sampler(self._labeled_videos, generator=self._video_random_generator)
+        else:
+            self._video_sampler = video_sampler(self._labeled_videos)
+
+        self._video_sampler_iter = None  # Initialized on first call to self.__next__()
+
+        # Unlike the clip-based datasets there is no clip sampling here; we only keep the most
+        # recently returned (video, label, index) triple in this variable so that it can be
+        # inspected after iteration.
+        self._loaded_video_label = None
+
+    def __next__(self) -> dict:
+        """Retrieves the next video tensor based on the video sampler.
+
+        Returns:
+            A dictionary with the following format.
+
+            .. code-block:: text
+
+                {
+                    'video': <video_tensor>,
+                    'label': <target_label>,
+                    'video_label': <target_label>,
+                    'video_index': <video_index>,
+                }
+        """
+        if not self._video_sampler_iter:
+            # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
+            self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))
+
+        # Every entry is already a fully decoded tensor, so there is no clip sampling step: simply
+        # fetch the next video selected by the video sampler.
+ video_index = next(self._video_sampler_iter) + video_tensor, info_dict = self._labeled_videos[video_index] + self._loaded_video_label = (video_tensor, info_dict, video_index) + + sample_dict = { + "video": self._loaded_video_label[0], + "video_name": f"video{video_index}", + "video_index": video_index, + "label": info_dict, + "video_label": info_dict, + } + + return sample_dict + + def __iter__(self): + self._video_sampler_iter = None # Reset video sampler + + # If we're in a PyTorch DataLoader multiprocessing context, we need to use the + # same seed for each worker's RandomSampler generator. The workers at each + # __iter__ call are created from the unique value: worker_info.seed - worker_info.id, + # which we can use for this seed. + worker_info = torch.utils.data.get_worker_info() + if self._video_random_generator is not None and worker_info is not None: + base_seed = worker_info.seed - worker_info.id + self._video_random_generator.manual_seed(base_seed) + + return self + + def size(self): + return len(self._labeled_videos) diff --git a/tests/video/classification/test_data.py b/tests/video/classification/test_data.py new file mode 100644 index 0000000000..965f2c38a3 --- /dev/null +++ b/tests/video/classification/test_data.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import pytest +import torch + +from flash.core.utilities.imports import _VIDEO_AVAILABLE +from flash.video.classification.data import VideoClassificationData + +if _VIDEO_AVAILABLE: + from pytorchvideo.data.utils import thwc_to_cthw + + +def create_dummy_video_frames(num_frames: int, height: int, width: int): + y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) + data = [] + for i in range(num_frames): + xc = float(i) / num_frames + yc = 1 - float(i) / (2 * num_frames) + d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 + data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) + return torch.stack(data, 0) + + +def temp_encoded_tensors(num_frames: int, height=10, width=10): + if not _VIDEO_AVAILABLE: + return torch.randint(size=(3, num_frames, height, width), low=0, high=255) + data = create_dummy_video_frames(num_frames, height, width) + return thwc_to_cthw(data).to(torch.float32) + + +def _check_len_and_values(got: list, expected: list): + assert len(got) == len(expected), f"Expected number of labels: {len(expected)}, but got: {len(got)}" + assert got == expected + + +def _check_frames(data, expected_frames_count: Union[list, int]): + if not isinstance(expected_frames_count, list): + expected_frames_count = [expected_frames_count] + + # to be replaced + assert data.size() == len( + expected_frames_count + ), f"Expected: {len(expected_frames_count)} but got {data.size()} samples in the dataset." 
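+    # Each sample yielded by ``LabeledVideoTensorDataset`` is a dict whose "video" entry holds the raw
+    # (C, T, H, W) tensor, so the frame count being checked is dimension 1 of that tensor.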
+ for idx, sample_dict in enumerate(data): + sample = sample_dict["video"] + assert ( + sample.shape[1] == expected_frames_count[idx] + ), f"Expected video sample {idx} to have {expected_frames_count[idx]} frames but got {sample.shape[1]} frames" + + +@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.parametrize( + "input_data, input_targets, expected_frames_count", + [ + ([temp_encoded_tensors(5), temp_encoded_tensors(5)], ["label1", "label2"], [5, 5]), + ([temp_encoded_tensors(5), temp_encoded_tensors(10)], ["label1", "label2"], [5, 10]), + (torch.stack((temp_encoded_tensors(5), temp_encoded_tensors(5))), ["label1", "label2"], [5, 5]), + (torch.stack((temp_encoded_tensors(5),)), ["label1"], [5]), + (temp_encoded_tensors(5), ["label1"], [5]), + ], +) +def test_load_data_from_tensors(input_data, input_targets, expected_frames_count): + datamodule = VideoClassificationData.from_tensors(train_data=input_data, train_targets=input_targets, batch_size=1) + _check_len_and_values(got=datamodule.labels, expected=input_targets) + _check_frames(data=datamodule.train_dataset.data, expected_frames_count=expected_frames_count) diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index bf3a15d6ff..1480b663e8 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -28,6 +28,7 @@ from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _VIDEO_AVAILABLE, _VIDEO_TESTING from flash.video import VideoClassificationData, VideoClassifier from tests.helpers.task_tester import TaskTester +from tests.video.classification.test_data import create_dummy_video_frames, temp_encoded_tensors if _FIFTYONE_AVAILABLE: import fiftyone as fo @@ -69,17 +70,6 @@ def example_test_sample(self): return self.example_train_sample -def create_dummy_video_frames(num_frames: int, height: int, width: int): - y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) - data = [] - for i in range(num_frames): - xc = float(i) / num_frames - yc = 1 - float(i) / (2 * num_frames) - d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 - data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) - return torch.stack(data, 0) - - # https://github.com/facebookresearch/pytorchvideo/blob/4feccb607d7a16933d485495f91d067f177dd8db/tests/utils.py#L33 @contextlib.contextmanager def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None, directory=None): @@ -228,6 +218,57 @@ def test_video_classifier_finetune_from_data_frame(tmpdir): trainer.finetune(model, datamodule=datamodule) +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_video_classifier_finetune_from_tensors(tmpdir): + mock_tensors = temp_encoded_tensors(num_frames=5) + datamodule = VideoClassificationData.from_tensors( + train_data=[mock_tensors, mock_tensors], + train_targets=["Patient", "Doctor"], + video_sampler=SequentialSampler, + batch_size=1, + ) + + for sample in datamodule.train_dataset.data: + expected_t_shape = 5 + assert sample["video"].shape[1] == expected_t_shape + + assert len(datamodule.labels) == 2, f"Expected number of labels to be 2 but found {len(datamodule.labels)}" + + model = VideoClassifier( + num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50", labels=datamodule.labels + ) + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=torch.cuda.device_count()) + trainer.finetune(model, datamodule=datamodule) + 
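+# Prediction on raw tensors goes through ``VideoClassificationTensorsPredictInput`` (added in
+# flash/video/classification/input.py above): the test below finetunes on two labelled tensors and then checks that
+# the predicted labels come from the label set supplied via ``train_targets``.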
+ +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_video_classifier_predict_from_tensors(tmpdir): + mock_tensors = temp_encoded_tensors(num_frames=5) + datamodule = VideoClassificationData.from_tensors( + train_data=[mock_tensors, mock_tensors], + train_targets=["Patient", "Doctor"], + predict_data=[mock_tensors, mock_tensors], + video_sampler=SequentialSampler, + batch_size=1, + ) + + for sample in datamodule.train_dataset.data: + expected_t_shape = 5 + assert sample["video"].shape[1] == expected_t_shape + + assert len(datamodule.labels) == 2, f"Expected number of labels to be 2 but found {len(datamodule.labels)}" + + model = VideoClassifier( + num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50", labels=datamodule.labels + ) + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=torch.cuda.device_count()) + trainer.finetune(model, datamodule=datamodule) + predictions = trainer.predict(model, datamodule=datamodule, output="labels") + + assert predictions is not None + assert predictions[0][0] in datamodule.labels + + @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_csv(tmpdir): with mock_video_csv_file(tmpdir) as (mock_csv, total_duration):