diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py
index 403a5fe081..c05cc93dcc 100644
--- a/flash/data/base_viz.py
+++ b/flash/data/base_viz.py
@@ -1,6 +1,7 @@
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List, Sequence, Set
 
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 from flash.core.utils import _is_overriden
 from flash.data.callback import BaseDataFetcher
@@ -94,14 +95,19 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
 
     """
 
-    def _show(self, running_stage: RunningStage) -> None:
-        self.show(self.batches[running_stage], running_stage)
+    def _show(self, running_stage: RunningStage, func_names_list: List[str]) -> None:
+        self.show(self.batches[running_stage], running_stage, func_names_list)
 
-    def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
+    def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str]) -> None:
         """
         Override this function when you want to visualize a composition.
         """
-        for func_name in _PREPROCESS_FUNCS:
+        # filter out the functions to visualise
+        func_names_set: Set[str] = set(func_names_list) & set(_PREPROCESS_FUNCS)
+        if len(func_names_set) == 0:
+            raise MisconfigurationException(f"Invalid function names: {func_names_list}.")
+
+        for func_name in func_names_set:
             hook_name = f"show_{func_name}"
             if _is_overriden(hook_name, self, BaseVisualization):
                 getattr(self, hook_name)(batch[func_name], running_stage)
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index fea3e7bbc3..25c3cad03b 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 import platform
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import pytorch_lightning as pl
 import torch
@@ -149,7 +149,7 @@ def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
         setattr(self, iter_name, iterator)
         return iterator
 
-    def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
+    def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], reset: bool = True) -> None:
         """
         This function is used to handle transforms profiling for batch visualization.
""" @@ -158,6 +158,10 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None: if not hasattr(self, iter_name): self._reset_iterator(stage) + # list of functions to visualise + if isinstance(func_names, str): + func_names = [func_names] + iter_dataloader = getattr(self, iter_name) with self.data_fetcher.enable(): try: @@ -166,25 +170,29 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None: iter_dataloader = self._reset_iterator(stage) _ = next(iter_dataloader) data_fetcher: BaseVisualization = self.data_fetcher - data_fetcher._show(stage) + data_fetcher._show(stage, func_names) if reset: self.viz.batches[stage] = {} - def show_train_batch(self, reset: bool = True) -> None: + def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the train dataloader.""" - self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset) + stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING] + self._show_batch(stage_name, hooks_names, reset=reset) - def show_val_batch(self, reset: bool = True) -> None: + def show_val_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the validation dataloader.""" - self._show_batch(_STAGES_PREFIX[RunningStage.VALIDATING], reset=reset) + stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING] + self._show_batch(stage_name, hooks_names, reset=reset) - def show_test_batch(self, reset: bool = True) -> None: + def show_test_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the test dataloader.""" - self._show_batch(_STAGES_PREFIX[RunningStage.TESTING], reset=reset) + stage_name: str = _STAGES_PREFIX[RunningStage.TESTING] + self._show_batch(stage_name, hooks_names, reset=reset) - def show_predict_batch(self, reset: bool = True) -> None: + def show_predict_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the predict dataloader.""" - self._show_batch(_STAGES_PREFIX[RunningStage.PREDICTING], reset=reset) + stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING] + self._show_batch(stage_name, hooks_names, reset=reset) @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: diff --git a/flash/utils/imports.py b/flash/utils/imports.py index 5252a3e3d5..eea3553463 100644 --- a/flash/utils/imports.py +++ b/flash/utils/imports.py @@ -5,4 +5,5 @@ _COCO_AVAILABLE = _module_available("pycocotools") _TIMM_AVAILABLE = _module_available("timm") _TORCHVISION_AVAILABLE = _module_available("torchvision") +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") _TRANSFORMERS_AVAILABLE = _module_available("transformers") diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index fc017f0350..d277b9416a 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -15,28 +15,37 @@ import pathlib from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +import numpy as np import torch import torchvision from PIL import Image from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.utils.data import 
 from torch.utils.data._utils.collate import default_collate
 from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset
 
 from flash.core.classification import ClassificationState
+from flash.core.utils import _is_overriden
 from flash.data.auto_dataset import AutoDataset
+from flash.data.base_viz import BaseVisualization  # for viz
 from flash.data.callback import BaseDataFetcher
 from flash.data.data_module import DataModule
 from flash.data.process import Preprocess
-from flash.utils.imports import _KORNIA_AVAILABLE
+from flash.data.utils import _PREPROCESS_FUNCS
+from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE
 
 if _KORNIA_AVAILABLE:
-    import kornia.augmentation as K
-    import kornia.geometry.transform as T
+    import kornia as K
 else:
     from torchvision import transforms as T
 
+if _MATPLOTLIB_AVAILABLE:
+    import matplotlib.pyplot as plt
+else:
+    plt = None
+
 
 class ImageClassificationPreprocess(Preprocess):
@@ -93,9 +102,11 @@ def default_train_transforms(self):
             # Better approach as all transforms are applied on tensor directly
             return {
                 "to_tensor_transform": torchvision.transforms.ToTensor(),
-                "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
+                "post_tensor_transform": nn.Sequential(
+                    K.augmentation.RandomResizedCrop(image_size), K.augmentation.RandomHorizontalFlip()
+                ),
                 "per_batch_transform_on_device": nn.Sequential(
-                    K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
+                    K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
                 )
             }
         else:
@@ -112,9 +123,9 @@ def default_val_transforms(self):
             # Better approach as all transforms are applied on tensor directly
             return {
                 "to_tensor_transform": torchvision.transforms.ToTensor(),
-                "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
+                "post_tensor_transform": nn.Sequential(K.augmentation.RandomResizedCrop(image_size)),
                 "per_batch_transform_on_device": nn.Sequential(
-                    K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
+                    K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
                 )
             }
         else:
@@ -159,18 +170,34 @@ def _load_data_dir(
         dataset: Optional[AutoDataset] = None,
     ) -> Tuple[Optional[List[str]], List[Tuple[str, int]]]:
         if isinstance(data, list):
+            # TODO: define num_classes elsewhere. This is a bad assumption since the list of
+            # labels might not contain the complete set of ids so that you can infer the total
+            # number of classes to train in your dataset.
             dataset.num_classes = len(data)
-            out = []
+            out: List[Tuple[str, int]] = []
            for p, label in data:
                 if os.path.isdir(p):
-                    for f in os.listdir(p):
+                    # TODO: there is an issue here when a path is provided along with labels.
+                    # os.listdir cannot assure the same file order as the passed labels list.
+                    files_list: List[str] = os.listdir(p)
+                    if len(files_list) > 1:
+                        raise ValueError(
+                            f"The provided directory contains more than one file. "
+ f"Directory: {p} -> Contains: {files_list}" + ) + for f in files_list: if has_file_allowed_extension(f, IMG_EXTENSIONS): out.append([os.path.join(p, f), label]) - elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS): + elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS): out.append([p, label]) + else: + raise TypeError(f"Unexpected file path type: {p}.") return None, out else: classes, class_to_idx = cls._find_classes(data) + # TODO: define num_classes elsewhere. This is a bad assumption since the list of + # labels might not contain the complete set of ids so that you can infer the total + # number of classes to train in your dataset. dataset.num_classes = len(classes) return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None) @@ -318,6 +345,14 @@ def __init__( if self._predict_ds: self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes) + def set_block_viz_window(self, value: bool) -> None: + """Setter method to switch on/off matplotlib to pop up windows.""" + self.data_fetcher.block_viz_window = value + + @staticmethod + def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: + return MatplotlibVisualization(*args, **kwargs) + @property def num_classes(self) -> int: if self._num_classes is None: @@ -494,3 +529,72 @@ def from_filepaths( seed=seed, **kwargs ) + + +class MatplotlibVisualization(BaseVisualization): + """Process and show the image batch and its associated label using matplotlib. + """ + max_cols: int = 4 # maximum number of columns we accept + block_viz_window: bool = True # parameter to allow user to block visualisation windows + + @staticmethod + def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: + out: np.ndarray + if isinstance(img, Image.Image): + out = np.array(img) + elif isinstance(img, torch.Tensor): + out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() + else: + raise TypeError(f"Unknown image type. Got: {type(img)}.") + return out + + def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): + # define the image grid + cols: int = min(num_samples, self.max_cols) + rows: int = num_samples // cols + + if not _MATPLOTLIB_AVAILABLE: + raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib") + + # create figure and set title + fig, axs = plt.subplots(rows, cols) + fig.suptitle(title) + + for i, ax in enumerate(axs.ravel()): + # unpack images and labels + if isinstance(data, list): + _img, _label = data[i] + elif isinstance(data, tuple): + imgs, labels = data + _img, _label = imgs[i], labels[i] + else: + raise TypeError(f"Unknown data type. 
Got: {type(data)}.") + # convert images to numpy + _img: np.ndarray = self._to_numpy(_img) + if isinstance(_label, torch.Tensor): + _label = _label.squeeze().tolist() + # show image and set label as subplot title + ax.imshow(_img) + ax.set_title(str(_label)) + ax.axis('off') + plt.show(block=self.block_viz_window) + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_load_sample" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_pre_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_to_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_post_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_per_batch_transform(self, batch: List[Any], running_stage): + win_title: str = f"{running_stage} - show_per_batch_transform" + self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title) diff --git a/requirements.txt b/requirements.txt index 718b47a08e..f3bab667d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ sentencepiece>=0.1.95 filelock # comes with 3rd-party dependency pycocotools>=2.0.2 ; python_version >= "3.7" kornia>=0.5.0 +matplotlib # used by the visualisation callback diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index ad2e2bfb61..a518c1d96d 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -26,7 +26,7 @@ from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX +from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -146,8 +146,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: for stage in _STAGES_PREFIX.values(): for _ in range(10): - fcn = getattr(dm, f"show_{stage}_batch") - fcn(reset=False) + for fcn_name in _PREPROCESS_FUNCS: + fcn = getattr(dm, f"show_{stage}_batch") + fcn(fcn_name, reset=False) is_predict = stage == "predict" @@ -206,3 +207,4 @@ def test_data_loaders_num_workers_to_0(tmpdir): assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter) iterator = iter(datamodule.train_dataloader()) assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter) + assert datamodule.num_workers == 3 diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index 24d30bfd8a..3cf6852444 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -15,6 +15,7 @@ from pathlib import Path import numpy as np +import pytest import torch from PIL import Image @@ -31,69 +32,132 @@ def _rand_image(): return Image.fromarray(np.random.randint(0, 255, (_size, _size, 3), dtype="uint8")) -def test_from_filepaths(tmpdir): +def test_from_filepaths_smoke(tmpdir): tmpdir = Path(tmpdir) (tmpdir / "a").mkdir() (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - 
_rand_image().save(tmpdir / "a" / "a_2.png") - - _rand_image().save(tmpdir / "b" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_2.png") + _rand_image().save(tmpdir / "a_1.png") + _rand_image().save(tmpdir / "b_1.png") img_data = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_transform=None, - train_labels=[0, 1], + train_filepaths=[tmpdir / "a_1.png", tmpdir / "b_1.png"], + train_labels=[1, 2], batch_size=2, num_workers=0, ) + assert img_data.train_dataloader() is not None + assert img_data.val_dataloader() is None + assert img_data.test_dataloader() is None + data = next(iter(img_data.train_dataloader())) imgs, labels = data assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) + assert sorted(list(labels.numpy())) == [1, 2] - assert img_data.val_dataloader() is None - assert img_data.test_dataloader() is None - (tmpdir / "c").mkdir() - (tmpdir / "d").mkdir() - _rand_image().save(tmpdir / "c" / "c_1.png") - _rand_image().save(tmpdir / "c" / "c_2.png") - _rand_image().save(tmpdir / "d" / "d_1.png") - _rand_image().save(tmpdir / "d" / "d_2.png") +def test_from_filepaths_list_image_paths(tmpdir): + tmpdir = Path(tmpdir) (tmpdir / "e").mkdir() - (tmpdir / "f").mkdir() - _rand_image().save(tmpdir / "e" / "e_1.png") - _rand_image().save(tmpdir / "e" / "e_2.png") - _rand_image().save(tmpdir / "f" / "f_1.png") - _rand_image().save(tmpdir / "f" / "f_2.png") + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] img_data = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - train_transform=None, - val_filepaths=[tmpdir / "c", tmpdir / "d"], - val_labels=[0, 1], - val_transform=None, - test_transform=None, - test_filepaths=[tmpdir / "e", tmpdir / "f"], - test_labels=[0, 1], - batch_size=1, + train_filepaths=train_images, + train_labels=[0, 3, 6], + val_filepaths=train_images, + val_labels=[1, 4, 7], + test_filepaths=train_images, + test_labels=[2, 5, 8], + batch_size=2, num_workers=0, ) + # check training data + data = next(iter(img_data.train_dataloader())) + imgs, labels = data + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here + assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here + + # check validation data data = next(iter(img_data.val_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, ) + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [1, 4] + # check test data data = next(iter(img_data.test_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, ) + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [2, 5] + + +def test_from_filepaths_visualise(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a" / "a_1.png") + _rand_image().save(tmpdir / "b" / "b_1.png") + + dm = ImageClassificationData.from_filepaths( + train_filepaths=[tmpdir / "a", tmpdir / "b"], + train_labels=[0, 1], + val_filepaths=[tmpdir / "b", tmpdir / "a"], + val_labels=[0, 2], + test_filepaths=[tmpdir / "b", tmpdir / "b"], + test_labels=[2, 1], + batch_size=2, + ) + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True 
+    dm.set_block_viz_window(False)
+    assert dm.data_fetcher.block_viz_window is False
+
+    # call show functions
+    dm.show_train_batch()
+    dm.show_train_batch("pre_tensor_transform")
+    dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+
+
+def test_from_filepaths_visualise_multilabel(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+    _rand_image().save(tmpdir / "a" / "a_1.png")
+    _rand_image().save(tmpdir / "b" / "b_1.png")
+
+    dm = ImageClassificationData.from_filepaths(
+        train_filepaths=[tmpdir / "a", tmpdir / "b"],
+        train_labels=[[0, 1, 0], [0, 1, 1]],
+        val_filepaths=[tmpdir / "b", tmpdir / "a"],
+        val_labels=[[1, 1, 0], [0, 0, 1]],
+        test_filepaths=[tmpdir / "b", tmpdir / "b"],
+        test_labels=[[0, 0, 1], [1, 1, 0]],
+        batch_size=2,
+    )
+    # disable visualisation for testing
+    assert dm.data_fetcher.block_viz_window is True
+    dm.set_block_viz_window(False)
+    assert dm.data_fetcher.block_viz_window is False
+
+    # call show functions
+    dm.show_train_batch()
+    dm.show_train_batch("pre_tensor_transform")
+    dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+    dm.show_val_batch("per_batch_transform")
 
 
 def test_categorical_csv_labels(tmpdir):
@@ -143,11 +207,9 @@ def index_col_collate_fn(x):
     test_labels = labels_from_categorical_csv(
         test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn
     )
+    B: int = 2  # batch_size
     data = ImageClassificationData.from_filepaths(
-        batch_size=2,
-        train_transform=None,
-        val_transform=None,
-        test_transform=None,
+        batch_size=B,
         train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'),
         train_labels=train_labels.values(),
         val_filepaths=os.path.join(tmpdir, 'some_dataset', 'valid'),
@@ -158,15 +220,18 @@ def index_col_collate_fn(x):
 
     for (x, y) in data.train_dataloader():
         assert len(x) == 2
+        assert sorted(list(y.numpy())) == sorted(list(train_labels.values())[:B])
 
     for (x, y) in data.val_dataloader():
         assert len(x) == 2
+        assert sorted(list(y.numpy())) == sorted(list(val_labels.values())[:B])
 
     for (x, y) in data.test_dataloader():
         assert len(x) == 2
+        assert sorted(list(y.numpy())) == sorted(list(test_labels.values())[:B])
 
 
-def test_from_folders(tmpdir):
+def test_from_folders_only_train(tmpdir):
     train_dir = Path(tmpdir / "train")
     train_dir.mkdir()
 
@@ -179,6 +244,7 @@ def test_from_folders(tmpdir):
     _rand_image().save(train_dir / "b" / "2.png")
 
     img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1)
+
     data = next(iter(img_data.train_dataloader()))
     imgs, labels = data
     assert imgs.shape == (1, 3, 196, 196)
@@ -187,23 +253,43 @@ def test_from_folders(tmpdir):
     assert img_data.val_dataloader() is None
     assert img_data.test_dataloader() is None
 
+
+def test_from_folders_train_val(tmpdir):
+
+    train_dir = Path(tmpdir / "train")
+    train_dir.mkdir()
+
+    (train_dir / "a").mkdir()
+    _rand_image().save(train_dir / "a" / "1.png")
+    _rand_image().save(train_dir / "a" / "2.png")
+
+    (train_dir / "b").mkdir()
+    _rand_image().save(train_dir / "b" / "1.png")
+    _rand_image().save(train_dir / "b" / "2.png")
     img_data = ImageClassificationData.from_folders(
         train_dir,
         val_folder=train_dir,
         test_folder=train_dir,
-        batch_size=1,
+        batch_size=2,
         num_workers=0,
     )
 
+    data = next(iter(img_data.train_dataloader()))
+    imgs, labels = data
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+
     data = next(iter(img_data.val_dataloader()))
     imgs, labels = data
-    assert imgs.shape == (1, 3, 196, 196)
-    assert labels.shape == (1, )
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert list(labels.numpy()) == [0, 0]
 
     data = next(iter(img_data.test_dataloader()))
     imgs, labels = data
-    assert imgs.shape == (1, 3, 196, 196)
-    assert labels.shape == (1, )
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert list(labels.numpy()) == [0, 0]
 
 
 def test_from_filepaths_multilabel(tmpdir):
diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/vision/classification/test_data_model_integration.py
index fb00d93b0f..4bd70455ec 100644
--- a/tests/vision/classification/test_data_model_integration.py
+++ b/tests/vision/classification/test_data_model_integration.py
@@ -35,10 +35,8 @@ def test_classification(tmpdir):
     (tmpdir / "a").mkdir()
     (tmpdir / "b").mkdir()
     _rand_image().save(tmpdir / "a" / "a_1.png")
-    _rand_image().save(tmpdir / "a" / "a_2.png")
-    _rand_image().save(tmpdir / "b" / "a_1.png")
-    _rand_image().save(tmpdir / "b" / "a_2.png")
+
 
     data = ImageClassificationData.from_filepaths(
         train_filepaths=[tmpdir / "a", tmpdir / "b"],
         train_labels=[0, 1],
@@ -46,6 +44,6 @@ def test_classification(tmpdir):
         num_workers=0,
         batch_size=2,
     )
-    model = ImageClassifier(2, backbone="resnet18")
+    model = ImageClassifier(num_classes=2, backbone="resnet18")
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.finetune(model, datamodule=data, strategy="freeze")
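
A minimal usage sketch of the visualization API added by this patch. The image paths and labels below are hypothetical placeholders; only `ImageClassificationData.from_filepaths`, `set_block_viz_window`, the `'load_sample'` default, and the hook-name filtering of `show_train_batch` come from the changes above:

    from flash.vision import ImageClassificationData

    # Hypothetical files and labels; substitute a real dataset.
    datamodule = ImageClassificationData.from_filepaths(
        train_filepaths=["images/cat_1.png", "images/dog_1.png"],
        train_labels=[0, 1],
        batch_size=2,
        num_workers=0,
    )

    # Keep the matplotlib windows from blocking execution (useful in scripts and tests).
    datamodule.set_block_viz_window(False)

    # Default: visualize the output of the 'load_sample' hook.
    datamodule.show_train_batch()

    # Or pass one or more preprocess hook names to filter what gets shown.
    datamodule.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])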