diff --git a/.github/workflows/ci-notebook.yml b/.github/workflows/ci-notebook.yml index 54f1210329..581c594b56 100644 --- a/.github/workflows/ci-notebook.yml +++ b/.github/workflows/ci-notebook.yml @@ -56,10 +56,10 @@ jobs: # Look to see if there is a cache hit for the corresponding requirements file key: flash-datasets_predict - #- name: Run Notebooks - # run: | - # jupyter nbconvert --to script flash_notebooks/image_classification.ipynb - # jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb - # - # ipython flash_notebooks/image_classification.py - # ipython flash_notebooks/tabular_classification.py + - name: Run Notebooks + run: | + # jupyter nbconvert --to script flash_notebooks/image_classification.ipynb + jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb + + # ipython flash_notebooks/image_classification.py + ipython flash_notebooks/tabular_classification.py diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 901eadf338..6cb28e97c0 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -60,8 +60,8 @@ jobs: run: | python --version pip --version - pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip list shell: bash diff --git a/.gitignore b/.gitignore index 935add8035..6717726144 100644 --- a/.gitignore +++ b/.gitignore @@ -143,7 +143,8 @@ flash_notebooks/*.py flash_notebooks/data MNIST* titanic -coco128 hymenoptera_data -xsum imdb +xsum +coco128 +wmt_en_ro diff --git a/README.md b/README.md index eab38315d7..6c60b5ceb0 100644 --- a/README.md +++ b/README.md @@ -254,8 +254,8 @@ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') datamodule = TabularData.from_csv( "./data/titanic/titanic.csv", test_csv="./data/titanic/test.csv", - categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_input=["Fare"], + cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + num_cols=["Fare"], target="Survived", val_size=0.25, ) diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst index 539df81a83..a97d89dbcf 100644 --- a/docs/source/custom_task.rst +++ b/docs/source/custom_task.rst @@ -67,33 +67,6 @@ for the prediction of diabetes disease progression. We can create this ``DataModule`` below, wrapping the scikit-learn `Diabetes dataset `__. -.. testcode:: - - class DiabetesPipeline(flash.core.data.TaskDataPipeline): - def after_uncollate(self, samples): - return [f"disease progression: {float(s):.2f}" for s in samples] - - class DiabetesData(flash.DataModule): - def __init__(self, batch_size=64, num_workers=0): - x, y = datasets.load_diabetes(return_X_y=True) - x = torch.from_numpy(x).float() - y = torch.from_numpy(y).float().unsqueeze(1) - x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0) - - train_ds = TensorDataset(x_train, y_train) - test_ds = TensorDataset(x_test, y_test) - - super().__init__( - train_ds=train_ds, - test_ds=test_ds, - batch_size=batch_size, - num_workers=num_workers - ) - self.num_inputs = x.shape[1] - - @staticmethod - def default_pipeline(): - return DiabetesPipeline() You’ll notice we added a ``DataPipeline``, which will be used when we call ``.predict()`` on our model. 
In this case we want to nicely format diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index 08bcae266a..e13d97440e 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -7,50 +7,6 @@ Data DataPipeline ------------ -To make tasks work for inference, one must create a ``DataPipeline``. -The ``flash.core.data.DataPipeline`` exposes 6 hooks to override: - -.. code:: python - - class DataPipeline: - """ - This class purpose is to facilitate the conversion of raw data to processed or batched data and back. - Several hooks are provided for maximum flexibility. - - collate_fn: - - before_collate - - collate - - after_collate - - uncollate_fn: - - before_uncollate - - uncollate - - after_uncollate - """ - - def before_collate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - return samples - - def collate(self, samples: Any) -> Any: - """Override to convert a set of samples to a batch""" - if not isinstance(samples, Tensor): - return default_collate(samples) - return samples - - def after_collate(self, batch: Any) -> Any: - """Override to apply transformations to the batch""" - return batch - - def before_uncollate(self, batch: Any) -> Any: - """Override to apply transformations to the batch""" - return batch - - def uncollate(self, batch: Any) -> ny: - """Override to convert a batch to a set of samples""" - samples = batch - return samples - - def after_uncollate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - return samplesA +To make tasks work for inference, one must create a ``Preprocess`` and ``PostProcess``. +The ``flash.data.process.Preprocess`` exposes 9 hooks to override which can specifialzed for each stage using +``train``, ``val``, ``test``, ``predict`` prefixes. diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index dc158c1320..45126f20de 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -45,7 +45,7 @@ Use the :class:`~flash.vision.ImageClassifier` pretrained model for inference on print(predictions) # 3b. Or generate predictions with a whole folder! - datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/") + datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/") predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) @@ -185,5 +185,3 @@ ImageClassificationData .. automethod:: flash.vision.ImageClassificationData.from_filepaths .. automethod:: flash.vision.ImageClassificationData.from_folders - -.. automethod:: flash.vision.ImageClassificationData.from_folder diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst index 48d95db288..f2c2b2b36f 100644 --- a/docs/source/reference/image_embedder.rst +++ b/docs/source/reference/image_embedder.rst @@ -23,14 +23,14 @@ Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on a .. code-block:: python - from flash.vision import ImageEmbedder + from flash.vision import ImageEmbedder - # Load finetuned task - embedder = ImageEmbedder(backbone="resnet18") + # Load finetuned task + embedder = ImageEmbedder(backbone="resnet18") - # 2. Perform inference on an image file - embeddings = embedder.predict("path/to/image.png") - print(embeddings) + # 2. 
Perform inference on an image file + embeddings = embedder.predict("path/to/image.png") + print(embeddings) Or on a random image tensor diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 2840923ca0..bed0c9fd53 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -41,14 +41,14 @@ Use the :class:`~flash.vision.ObjectDetector` pretrained model for inference on .. code-block:: python - from flash.vision import ObjectDetector + from flash.vision import ObjectDetector - # 1. Load the model - detector = ObjectDetector() + # 1. Load the model + detector = ObjectDetector() - # 2. Perform inference on an image file - predictions = detector.predict("path/to/image.png") - print(predictions) + # 2. Perform inference on an image file + predictions = detector.predict("path/to/image.png") + print(predictions) Or on a random image tensor diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index 1aab5296ed..8952ffa1eb 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -35,8 +35,8 @@ We can use the Flash Tabular classification task to predict the probability a pa We can create :class:`~flash.tabular.TabularData` from csv files using the :func:`~flash.tabular.TabularData.from_csv` method. We will pass in: * **train_csv**- csv file containing the training data converted to a Pandas DataFrame -* **categorical_input**- a list of the names of columns that contain categorical data (strings or integers) -* **numerical_input**- a list of the names of columns that contain numerical continuous data (floats) +* **cat_cols**- a list of the names of columns that contain categorical data (strings or integers) +* **num_cols**- a list of the names of columns that contain numerical continuous data (floats) * **target**- the name of the column we want to predict @@ -56,8 +56,8 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da datamodule = TabularData.from_csv( "./data/titanic/titanic.csv", test_csv="./data/titanic/test.csv", - categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_input=["Fare"], + cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + num_cols=["Fare"], target="Survived", val_size=0.25, ) @@ -120,8 +120,8 @@ Or you can finetune your own model and use that for prediction: datamodule = TabularData.from_csv( "my_data_file.csv", test_csv="./data/titanic/test.csv", - categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_input=["Fare"], + cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + num_cols=["Fare"], target="Survived", val_size=0.25, ) @@ -166,4 +166,3 @@ TabularData .. automethod:: flash.tabular.TabularData.from_csv .. 
automethod:: flash.tabular.TabularData.from_df - diff --git a/flash/__init__.py b/flash/__init__.py index bd2f2fd44d..48650b85e6 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -28,12 +28,11 @@ _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) from flash import tabular, text, vision # noqa: E402 -from flash.core import data, utils # noqa: E402 from flash.core.classification import ClassificationTask # noqa: E402 -from flash.core.data import DataModule # noqa: E402 -from flash.core.data.utils import download_data # noqa: E402 from flash.core.model import Task # noqa: E402 from flash.core.trainer import Trainer # noqa: E402 +from flash.data.data_module import DataModule # noqa: E402 +from flash.data.utils import download_data # noqa: E402 __all__ = [ "Task", @@ -42,7 +41,5 @@ "vision", "text", "tabular", - "data", - "utils", "download_data", ] diff --git a/flash/core/classification.py b/flash/core/classification.py index f627f8cbf2..86b4066410 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -16,23 +16,18 @@ import torch from torch import Tensor -from flash.core.data import TaskDataPipeline from flash.core.model import Task +from flash.data.process import Postprocess -class ClassificationDataPipeline(TaskDataPipeline): +class ClassificationPostprocess(Postprocess): - def before_uncollate(self, batch: Union[Tensor, tuple]) -> Tensor: - if isinstance(batch, tuple): - batch = batch[0] - return torch.softmax(batch, -1) - - def after_uncollate(self, samples: Any) -> Any: + def per_sample_transform(self, samples: Any) -> Any: return torch.argmax(samples, -1).tolist() class ClassificationTask(Task): - @staticmethod - def default_pipeline() -> ClassificationDataPipeline: - return ClassificationDataPipeline() + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._postprocess = ClassificationPostprocess() diff --git a/flash/core/data/__init__.py b/flash/core/data/__init__.py deleted file mode 100644 index 96aad59678..0000000000 --- a/flash/core/data/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from flash.core.data.datamodule import DataModule, TaskDataPipeline -from flash.core.data.datapipeline import DataPipeline -from flash.core.data.utils import download_data diff --git a/flash/core/data/datamodule.py b/flash/core/data/datamodule.py deleted file mode 100644 index d32699d2eb..0000000000 --- a/flash/core/data/datamodule.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os -import platform -from typing import Any, Optional - -import pytorch_lightning as pl -from torch.utils.data import DataLoader, Dataset - -from flash.core.data.datapipeline import DataPipeline - - -class TaskDataPipeline(DataPipeline): - - def after_collate(self, batch: Any) -> Any: - return (batch["x"], batch["target"]) if isinstance(batch, dict) else batch - - -class DataModule(pl.LightningDataModule): - """Basic DataModule class for all Flash tasks - - Args: - train_ds: Dataset for training. Defaults to None. - valid_ds: Dataset for validating model performance during training. Defaults to None. - test_ds: Dataset to test model performance. Defaults to None. - batch_size: the batch size to be used by the DataLoader. Defaults to 1. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. - """ - - def __init__( - self, - train_ds: Optional[Dataset] = None, - valid_ds: Optional[Dataset] = None, - test_ds: Optional[Dataset] = None, - batch_size: int = 1, - num_workers: Optional[int] = None, - ): - super().__init__() - self._train_ds = train_ds - self._valid_ds = valid_ds - self._test_ds = test_ds - - if self._train_ds is not None: - self.train_dataloader = self._train_dataloader - - if self._valid_ds is not None: - self.val_dataloader = self._val_dataloader - - if self._test_ds is not None: - self.test_dataloader = self._test_dataloader - - self.batch_size = batch_size - - # TODO: figure out best solution for setting num_workers - if num_workers is None: - if platform.system() == "Darwin": - num_workers = 0 - else: - num_workers = os.cpu_count() - self.num_workers = num_workers - - self._data_pipeline = None - - def _train_dataloader(self) -> DataLoader: - return DataLoader( - self._train_ds, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True, - collate_fn=self.data_pipeline.collate_fn, - drop_last=True, - ) - - def _val_dataloader(self) -> DataLoader: - return DataLoader( - self._valid_ds, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - collate_fn=self.data_pipeline.collate_fn, - ) - - def _test_dataloader(self) -> DataLoader: - return DataLoader( - self._test_ds, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - collate_fn=self.data_pipeline.collate_fn, - ) - - @property - def data_pipeline(self) -> DataPipeline: - if self._data_pipeline is None: - self._data_pipeline = self.default_pipeline() - return self._data_pipeline - - @data_pipeline.setter - def data_pipeline(self, data_pipeline) -> None: - self._data_pipeline = data_pipeline - - @staticmethod - def default_pipeline() -> DataPipeline: - return TaskDataPipeline() diff --git a/flash/core/data/datapipeline.py b/flash/core/data/datapipeline.py deleted file mode 100644 index 17b91008e9..0000000000 --- a/flash/core/data/datapipeline.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -from torch import Tensor -from torch.utils.data._utils.collate import default_collate - - -class DataPipeline: - """ - This class purpose is to facilitate the conversion of raw data to processed or batched data and back. - Several hooks are provided for maximum flexibility. - - Example:: - - .. code-block:: python - - class MyTextDataPipeline(DataPipeline): - def __init__(self, tokenizer, padder): - self.tokenizer = tokenizer - self.padder = padder - - def before_collate(self, samples): - # encode each input sequence - return [self.tokenizer.encode(sample) for sample in samplers] - - def after_collate(self, batch): - # pad tensor elements to the maximum length in the batch - return self.padder(batch) - - def after_uncollate(self, samples): - # decode each input sequence - return [self.tokenizer.decode(sample) for sample in samples] - - """ - - def before_collate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - return samples - - def collate(self, samples: Any) -> Any: - """Override to convert a set of samples to a batch""" - if not isinstance(samples, Tensor): - return default_collate(samples) - return samples - - def after_collate(self, batch: Any) -> Any: - """Override to apply transformations to the batch""" - return batch - - def collate_fn(self, samples: Any) -> Any: - """ - Utility function to convert raw data to batched data - - ``collate_fn`` as used in ``torch.utils.data.DataLoader``. - To avoid the before/after collate transformations, please use ``collate``. - """ - samples = self.before_collate(samples) - batch = self.collate(samples) - batch = self.after_collate(batch) - return batch - - def before_uncollate(self, batch: Any) -> Any: - """Override to apply transformations to the batch""" - return batch - - def uncollate(self, batch: Any) -> Any: - """Override to convert a batch to a set of samples""" - samples = batch - return samples - - def after_uncollate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - return samples - - def uncollate_fn(self, batch: Any) -> Any: - """Utility function to convert batched data back to raw data""" - batch = self.before_uncollate(batch) - samples = self.uncollate(batch) - samples = self.after_uncollate(samples) - return samples diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 2285018231..2e477056e5 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -40,7 +40,6 @@ def download_file(url: str, path: str, verbose: bool = False) -> None: if not os.path.exists(local_filename): r = requests.get(url, stream=True) file_size = int(r.headers.get('Content-Length', 0)) - chunk = 1 chunk_size = 1024 num_bars = int(file_size / chunk_size) if verbose: diff --git a/flash/core/model.py b/flash/core/model.py index e5f2dcef71..6cc7bcda5f 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -13,15 +13,17 @@ # limitations under the License. 
import functools import os -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch import torchmetrics from pytorch_lightning import LightningModule +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.states import RunningStage from torch import nn -from flash.core.data import DataModule, DataPipeline from flash.core.utils import get_callable_dict +from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess def predict_context(func: Callable) -> Callable: @@ -32,13 +34,16 @@ def predict_context(func: Callable) -> Callable: @functools.wraps(func) def wrapper(self, *args, **kwargs) -> Any: + grad_enabled = torch.is_grad_enabled() + is_training = self.training self.eval() torch.set_grad_enabled(False) result = func(self, *args, **kwargs) - self.train() - torch.set_grad_enabled(True) + if is_training: + self.train() + torch.set_grad_enabled(grad_enabled) return result return wrapper @@ -50,7 +55,7 @@ class Task(LightningModule): Args: model: Model to use for the task. loss_fn: Loss function for training - optimizer: Optimizer to use for training, defaults to `torch.optim.SGD`. + optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. metrics: Metrics to compute for training and evaluation. learning_rate: Learning rate to use for training, defaults to `5e-5` """ @@ -72,21 +77,23 @@ def __init__( self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") + self._data_pipeline = None + self._preprocess = None + self._postprocess = None def step(self, batch: Any, batch_idx: int) -> Any: """ The training/validation/test step. Override for custom behavior. """ x, y = batch - y_hat = self.forward(x) - output = {"y_hat": self.data_pipeline.before_uncollate(y_hat)} + y_hat = self(x) + output = {"y_hat": y_hat} losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} logs = {} for name, metric in self.metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): - output["y_hat"] = self.data_pipeline.before_uncollate(output["y_hat"]) - metric(output["y_hat"], y) + metric(y_hat, y) logs[name] = metric # log the metric itself if it is of type Metric else: logs[name] = metric(y_hat, y) @@ -119,74 +126,134 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, - batch_idx: Optional[int] = None, - skip_collate_fn: bool = False, - dataloader_idx: Optional[int] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ Predict function for raw data or processed data Args: - x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data. - batch_idx: Batch index - - dataloader_idx: Dataloader index - - skip_collate_fn: Whether to skip the collate step. 
- this is required when passing data already processed - for the model, for example, data from a dataloader - data_pipeline: Use this to override the current data pipeline Returns: The post-processed model predictions - """ - # enable x to be a path to a folder - if isinstance(x, str): - files = os.listdir(x) - files = [os.path.join(x, y) for y in files] - x = files - + running_stage = RunningStage.PREDICTING data_pipeline = data_pipeline or self.data_pipeline - batch = x if skip_collate_fn else data_pipeline.collate_fn(x) - batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None) - predictions = self.predict_step(batch_x, 0) - output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x - return output + x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)] + x = data_pipeline.worker_preprocessor(running_stage)(x) + x = self.transfer_batch_to_device(x, self.device) + x = data_pipeline.device_preprocessor(running_stage)(x) + predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict` + predictions = data_pipeline.postprocessor(predictions) + return predictions + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + if isinstance(batch, tuple): + batch = batch[0] + elif isinstance(batch, list): + # Todo: Understand why stack is needed + batch = torch.stack(batch) + return self(batch) def configure_optimizers(self) -> torch.optim.Optimizer: return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) + def configure_finetune_callback(self) -> List[Callback]: + return [] + + @property + def preprocess(self) -> Optional[Preprocess]: + return getattr(self._data_pipeline, '_preprocess_pipeline', None) or self._preprocess + + @preprocess.setter + def preprocess(self, preprocess: Preprocess) -> None: + self._preprocess = preprocess + self.data_pipeline = DataPipeline(preprocess, self.postprocess) + @property - def data_pipeline(self) -> DataPipeline: - # we need to save the pipeline in case this class - # is loaded from checkpoint and used to predict - if not self._data_pipeline: - try: - # datamodule pipeline takes priority - self._data_pipeline = self.trainer.datamodule.data_pipeline - except AttributeError: - self._data_pipeline = self.default_pipeline() + def postprocess(self) -> Postprocess: + return getattr(self._data_pipeline, '_postprocess_pipeline', None) or self._postprocess + + @postprocess.setter + def postprocess(self, postprocess: Postprocess) -> None: + self.data_pipeline = DataPipeline(self.preprocess, postprocess) + self._postprocess = postprocess + + @property + def data_pipeline(self) -> Optional[DataPipeline]: + if self._data_pipeline is not None: + return self._data_pipeline + + elif self.preprocess is not None or self.postprocess is not None: + # use direct attributes here to avoid recursion with properties that also check the data_pipeline property + return DataPipeline(self.preprocess, self.postprocess) + + elif self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: + return self.datamodule.data_pipeline + + elif self.trainer is not None and hasattr( + self.trainer, 'datamodule' + ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None: + return self.trainer.datamodule.data_pipeline + return self._data_pipeline @data_pipeline.setter - def data_pipeline(self, data_pipeline: DataPipeline) -> None: + def data_pipeline(self, data_pipeline: 
Optional[DataPipeline]) -> None: self._data_pipeline = data_pipeline - - @staticmethod - def default_pipeline() -> DataPipeline: - """Pipeline to use when there is no datamodule or it has not defined its pipeline""" - return DataModule.default_pipeline() - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - self.data_pipeline = checkpoint["pipeline"] + if data_pipeline is not None and getattr(data_pipeline, '_preprocess_pipeline', None) is not None: + self._preprocess = data_pipeline._preprocess_pipeline + + if data_pipeline is not None and getattr(data_pipeline, '_postprocess_pipeline', None) is not None: + if type(data_pipeline._postprocess_pipeline) != Postprocess: + self._postprocess = data_pipeline._postprocess_pipeline + + def on_train_dataloader(self) -> None: + if self.data_pipeline is not None: + self.data_pipeline._detach_from_model(self, RunningStage.TRAINING) + self.data_pipeline._attach_to_model(self, RunningStage.TRAINING) + super().on_train_dataloader() + + def on_val_dataloader(self) -> None: + if self.data_pipeline is not None: + self.data_pipeline._detach_from_model(self, RunningStage.VALIDATING) + self.data_pipeline._attach_to_model(self, RunningStage.VALIDATING) + super().on_val_dataloader() + + def on_test_dataloader(self, *_) -> None: + if self.data_pipeline is not None: + self.data_pipeline._detach_from_model(self, RunningStage.TESTING) + self.data_pipeline._attach_to_model(self, RunningStage.TESTING) + super().on_test_dataloader() + + def on_predict_dataloader(self) -> None: + if self.data_pipeline is not None: + self.data_pipeline._detach_from_model(self, RunningStage.PREDICTING) + self.data_pipeline._attach_to_model(self, RunningStage.PREDICTING) + super().on_predict_dataloader() + + def on_predict_end(self) -> None: + if self.data_pipeline is not None: + self.data_pipeline._detach_from_model(self) + super().on_predict_end() + + def on_fit_end(self) -> None: + if self.data_pipeline is not None: + self.data_pipeline._detach_from_model(self) + super().on_fit_end() def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - checkpoint["pipeline"] = self.data_pipeline + # TODO: Is this the best way to do this? or should we also use some kind of hparams here? + # This may be an issue since here we create the same problems with pickle as in + # https://pytorch.org/docs/stable/notes/serialization.html - def configure_finetune_callback(self): - return [] + if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: + checkpoint['data_pipeline'] = self.data_pipeline + super().on_save_checkpoint(checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_load_checkpoint(checkpoint) + if 'data_pipeline' in checkpoint: + self.data_pipeline = checkpoint['data_pipeline'] diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index bbaa13fe3e..be6e32038e 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -20,7 +20,7 @@ from torch.utils.data import Dataset from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX +from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES if TYPE_CHECKING: from flash.data.data_pipeline import DataPipeline @@ -28,8 +28,6 @@ class AutoDataset(Dataset): - FITTING_STAGES = ("train", "val") - STAGES = ("train", "test", "val", "predict") DATASET_KEY = "dataset" """ This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. 
@@ -89,11 +87,11 @@ def _call_load_sample(self, sample: Any) -> Any: else: return self.load_sample(sample) - def _setup(self, stage: RunningStage) -> None: - assert not stage or _STAGES_PREFIX[stage] in self.STAGES + def _setup(self, stage: Optional[RunningStage]) -> None: + assert not stage or _STAGES_PREFIX[stage] in _STAGES_PREFIX_VALUES previous_load_data = self.load_data.__code__ if self.load_data else None - if (self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage): + if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage: self.load_data = getattr( self.data_pipeline._preprocess_pipeline, self.data_pipeline._resolve_function_hierarchy( @@ -128,18 +126,12 @@ def _set_running_stage(self, stage: RunningStage) -> None: def __getitem__(self, index: int) -> Any: if not self.load_sample and not self.load_data: - raise RuntimeError( - "Names for LoadSample and LoadData could not be inferred." - " Consider setting the RunningStage" - ) + raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") if self.load_sample: return self._call_load_sample(self._preprocessed_data[index]) return self._preprocessed_data[index] def __len__(self) -> int: if not self.load_sample and not self.load_data: - raise RuntimeError( - "Names for LoadSample and LoadData could not be inferred." - " Consider setting the RunningStage" - ) + raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.") return len(self._preprocessed_data) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 4208c6e42c..2db64e457b 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -13,12 +13,10 @@ # limitations under the License. import os import platform -from copy import deepcopy from typing import Any, Callable, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch -from pytorch_lightning.core.datamodule import _DataModuleWrapper from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.nn import Module @@ -29,56 +27,20 @@ from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess -# TODO: unused? -class MockLightningModule(pl.LightningModule): - - pass - - -class TaskDataPipeline(DataPipeline): - - def per_batch_transform(self, batch: Any) -> Any: - return (batch["x"], batch.get('target', batch.get('y'))) if isinstance(batch, dict) else batch - - -class _FlashDataModuleWrapper(_DataModuleWrapper): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __call__(self, *args, **kwargs): - """A wrapper for ``DataModule`` that: - - TODO: describe what is __flash_special_attr__ for - """ - __flash_special_attr__ = getattr(self, "__flash_special_attr__", None) - saved_attr = [] - if __flash_special_attr__: - for special_attr_name in __flash_special_attr__: - attr = deepcopy(getattr(self, special_attr_name, None)) - saved_attr.append((special_attr_name, attr)) - - obj = super().__call__(*args, **kwargs) - - if __flash_special_attr__: - for special_attr_name, attr in saved_attr: - setattr(obj, special_attr_name, attr) - - return obj - - -class DataModule(pl.LightningDataModule, metaclass=_FlashDataModuleWrapper): - """Basic DataModule class for all Flash tasks. +class DataModule(pl.LightningDataModule): + """Basic DataModule class for all Flash tasks Args: - train_ds: Dataset for training. 
- valid_ds: Dataset for validating model performance during training. - test_ds: Dataset to test model performance. - predict_ds: Dataset for predicting. - batch_size: the batch size to be used by the DataLoader. + train_dataset: Dataset for training. Defaults to None. + valid_dataset: Dataset for validating model performance during training. Defaults to None. + test_dataset: Dataset to test model performance. Defaults to None. + predict_dataset: Dataset to predict model performance. Defaults to None. + num_workers: The number of workers to use for parallelized loading. Defaults to None. + predict_ds: Dataset for predicting. Defaults to None. + batch_size: The batch size to be used by the DataLoader. Defaults to 1. num_workers: The number of workers to use for parallelized loading. Defaults to None which equals the number of available CPU threads, - or 0 for MacOS. + or 0 for Darwin platform. """ preprocess_cls = Preprocess @@ -86,19 +48,19 @@ class DataModule(pl.LightningDataModule, metaclass=_FlashDataModuleWrapper): def __init__( self, - train_ds: Optional[Dataset] = None, - valid_ds: Optional[Dataset] = None, - test_ds: Optional[Dataset] = None, - predict_ds: Optional[Dataset] = None, + train_dataset: Optional[Dataset] = None, + valid_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None, + predict_dataset: Optional[Dataset] = None, batch_size: int = 1, num_workers: Optional[int] = None, ) -> None: super().__init__() - self._train_ds = train_ds - self._valid_ds = valid_ds - self._test_ds = test_ds - self._predict_ds = predict_ds + self._train_ds = train_dataset + self._valid_ds = valid_dataset + self._test_ds = test_dataset + self._predict_ds = predict_dataset if self._train_ds: self.train_dataloader = self._train_dataloader @@ -119,7 +81,6 @@ def __init__( num_workers = 0 if platform.system() == "Darwin" else os.cpu_count() self.num_workers = num_workers - self._data_pipeline = None self._preprocess = None self._postprocess = None @@ -206,11 +167,11 @@ def generate_auto_dataset(self, *args, **kwargs): @property def preprocess(self) -> Preprocess: - return self.preprocess_cls() + return self._preprocess or self.preprocess_cls() @property def postprocess(self) -> Postprocess: - return self.postprocess_cls() + return self._postprocess or self.postprocess_cls() @property def data_pipeline(self) -> DataPipeline: @@ -270,7 +231,7 @@ def train_valid_test_split( number of samples to be contained within test dataset. test_split: If Float, ratio of data to be contained within the test dataset. If Int, number of samples to be contained within test dataset. - seed: Used for the train/val splits when valid_split. + seed: Used for the train/val splits when valid_split is not None. 
""" n = len(dataset) @@ -335,6 +296,8 @@ def from_load_data_inputs( valid_load_data_input: Optional[Any] = None, test_load_data_input: Optional[Any] = None, predict_load_data_input: Optional[Any] = None, + preprocess: Optional[Preprocess] = None, + postprocess: Optional[Postprocess] = None, **kwargs, ) -> 'DataModule': """ @@ -349,20 +312,32 @@ def from_load_data_inputs( kwargs: Any extra arguments to instantiate the provided ``DataModule`` """ # trick to get data_pipeline from empty DataModule - data_pipeline = cls(**kwargs).data_pipeline - train_ds = cls._generate_dataset_if_possible( + if preprocess or postprocess: + data_pipeline = DataPipeline( + preprocess or cls(**kwargs).preprocess, + postprocess or cls(**kwargs).postprocess, + ) + else: + data_pipeline = cls(**kwargs).data_pipeline + train_dataset = cls._generate_dataset_if_possible( train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline ) - valid_ds = cls._generate_dataset_if_possible( + valid_dataset = cls._generate_dataset_if_possible( valid_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline ) - test_ds = cls._generate_dataset_if_possible( + test_dataset = cls._generate_dataset_if_possible( test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline ) - predict_ds = cls._generate_dataset_if_possible( + predict_dataset = cls._generate_dataset_if_possible( predict_load_data_input, running_stage=RunningStage.PREDICTING, data_pipeline=data_pipeline ) - datamodule = cls(train_ds=train_ds, valid_ds=valid_ds, test_ds=test_ds, predict_ds=predict_ds, **kwargs) + datamodule = cls( + train_dataset=train_dataset, + valid_dataset=valid_dataset, + test_dataset=test_dataset, + predict_dataset=predict_dataset, + **kwargs + ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline return datamodule diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 8130e22c42..b50e468c50 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -57,13 +57,13 @@ def __len__(self): _Sequential: - ┌───────────────────────── + ┌─────────────────────────┐ │ pre_tensor_transform │ - │ | | + │ | │ │ to_tensor_transform │ - │ | | + │ | │ │ post_tensor_transform │ - └────────────────────────── + └─────────────────────────┘ _PreProcessor: @@ -138,8 +138,6 @@ def forward(self, samples: Sequence[Any]): "per_batch_transform_on_device", "collate", } - # TODO: unused? 
- POSTPROCESS_FUNCS = ("per_batch_transform", "per_sample_transform", "save_data", "save_sample") def __init__(self, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None) -> None: self._preprocess_pipeline = preprocess or Preprocess() @@ -161,6 +159,11 @@ def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optiona return getattr(process_obj, current_method_name).__code__ != getattr(super_obj, method_name).__code__ + @property + def preprocess_state(self): + if self._preprocess_pipeline: + return self._preprocess_pipeline.state + @classmethod def _is_overriden_recursive( cls, method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None @@ -169,6 +172,7 @@ def _is_overriden_recursive( Cropped Version of https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py """ + assert isinstance(process_obj, super_obj) if prefix is None and not hasattr(super_obj, method_name): raise MisconfigurationException(f"This function doesn't belong to the parent class {super_obj}") @@ -197,7 +201,7 @@ def device_preprocessor(self, running_stage: RunningStage) -> _PreProcessor: @property def postprocessor(self) -> _PostProcessor: - return self._postprocessor | self._create_uncollate_postprocessors() + return self._postprocessor or self._create_uncollate_postprocessors() @postprocessor.setter def postprocessor(self, new_processor: _PostProcessor): @@ -211,16 +215,14 @@ def _resolve_function_hierarchy( object_type = Preprocess prefixes = [''] - - # TODO: Check if tuning uses training or validation data if stage in (RunningStage.TRAINING, RunningStage.TUNING): - prefixes = ['train', 'fit'] + prefixes + prefixes += ['train', 'fit'] elif stage == RunningStage.VALIDATING: - prefixes = ['val', 'fit'] + prefixes + prefixes += ['val', 'fit'] elif stage == RunningStage.TESTING: - prefixes = ['test'] + prefixes + prefixes += ['test'] elif stage == RunningStage.PREDICTING: - prefixes = ['predict'] + prefixes + prefixes += ['predict'] for prefix in prefixes: if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix): @@ -350,13 +352,13 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None setattr(model, final_name, new_loader) def _attach_preprocess_to_model( - self, model: 'Task', stages: Optional[RunningStage] = None, device_transform_only: bool = False + self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False ) -> None: - if not stages: + if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] - elif isinstance(stages, RunningStage): - stages = [stages] + elif isinstance(stage, RunningStage): + stages = [stage] for stage in stages: @@ -439,29 +441,24 @@ def _attach_postprocess_to_model(self, model: 'Task') -> 'Task': ) return model - def _attach_to_model(self, model: 'Task', stages: RunningStage = None): + def _attach_to_model(self, model: 'Task', stage: RunningStage = None): # not necessary to detach. preprocessing and postprocessing for stage will be overwritten. 
- self._attach_preprocess_to_model(model, stages) + self._attach_preprocess_to_model(model, stage) - if not stages or stages == RunningStage.PREDICTING: + if not stage or stage == RunningStage.PREDICTING: self._attach_postprocess_to_model(model) - def _detach_from_model(self, model: 'Task', stages: Optional[RunningStage] = None): - self._detach_preprocessing_from_model(model, stages) + def _detach_from_model(self, model: 'Task', stage: Optional[RunningStage] = None): + self._detach_preprocessing_from_model(model, stage) - if not stages or stages == RunningStage.PREDICTING: + if not stage or stage == RunningStage.PREDICTING: self._detach_postprocess_from_model(model) - @staticmethod - def _composed_collates(samples: Any, worker_collate: Callable, device_collate: Callable) -> Any: - return device_collate(worker_collate(samples)) - - def _detach_preprocessing_from_model(self, model: 'Task', stages: Optional[RunningStage] = None): - if not stages: + def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[RunningStage] = None): + if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] - - elif isinstance(stages, RunningStage): - stages = [stages] + elif isinstance(stage, RunningStage): + stages = [stage] for stage in stages: @@ -473,7 +470,7 @@ def _detach_preprocessing_from_model(self, model: 'Task', stages: Optional[Runni if model.transfer_batch_to_device.is_empty(): model.transfer_batch_to_device = model.transfer_batch_to_device.func - if device_collate is None: + if not device_collate: device_collate = self._identity loader_name = f'{_STAGES_PREFIX[stage]}_dataloader' diff --git a/flash/data/process.py b/flash/data/process.py index 9130cb408c..fbeb531fca 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -74,6 +74,14 @@ def validating(self, val: bool) -> None: self._running_stage = None +@dataclass(unsafe_hash=True, frozen=True) +class PreprocessState: + """ + Base class for all preprocess states + """ + pass + + class Preprocess(Properties, torch.nn.Module): def __init__( @@ -89,6 +97,10 @@ def __init__( self.test_transform = convert_to_modules(test_transform) self.predict_transform = convert_to_modules(predict_transform) + @classmethod + def from_state(cls, state: PreprocessState) -> 'Preprocess': + return cls(**vars(state)) + @classmethod def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Any: """Loads entire data from Dataset""" diff --git a/flash/data/utils.py b/flash/data/utils.py index 41c1378046..4be6d177ba 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -29,6 +29,7 @@ RunningStage.VALIDATING: 'val', RunningStage.PREDICTING: 'predict' } +_STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: @@ -83,7 +84,11 @@ def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool: return False -class LambdaModule(torch.nn.Module): +class FuncModule(torch.nn.Module): + """ + This class is used to wrap a callable within a nn.Module and + apply the wrapped function in `__call__` + """ def __init__(self, func: Callable) -> None: super().__init__() @@ -101,7 +106,7 @@ def convert_to_modules(transforms: Dict): if transforms is None or isinstance(transforms, torch.nn.Module): return transforms - transforms = apply_to_collection(transforms, Callable, LambdaModule, wrong_dtype=torch.nn.Module) + transforms = apply_to_collection(transforms, Callable, FuncModule, wrong_dtype=torch.nn.Module) 
transforms = apply_to_collection(transforms, Mapping, torch.nn.ModuleDict, wrong_dtype=torch.nn.ModuleDict) transforms = apply_to_collection( transforms, Iterable, torch.nn.ModuleList, wrong_dtype=(torch.nn.ModuleList, torch.nn.ModuleDict) diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index 75ae6dbf4c..6909a29d88 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -11,165 +11,268 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Type, Union import numpy as np import pandas as pd from pandas.core.frame import DataFrame +from pytorch_lightning.utilities.exceptions import MisconfigurationException from sklearn.model_selection import train_test_split -from torch import Tensor -from flash.core.classification import ClassificationDataPipeline -from flash.core.data import DataPipeline -from flash.core.data.datamodule import DataModule -from flash.core.data.utils import _contains_any_tensor +from flash.data.auto_dataset import AutoDataset +from flash.data.data_module import DataModule +from flash.data.process import Preprocess, PreprocessState from flash.tabular.classification.data.dataset import ( _compute_normalization, _dfs_to_samples, _generate_codes, _impute, _pre_transform, + _to_cat_vars_numpy, + _to_num_vars_numpy, PandasDataset, ) -class TabularDataPipeline(ClassificationDataPipeline): +@dataclass(unsafe_hash=True, frozen=True) +class TabularState(PreprocessState): + cat_cols: List[str] # categorical columns used for training + num_cols: List[str] # numerical columns used for training + target_col: str # target column name used for training + mean: DataFrame # mean DataFrame for categorical columsn on train DataFrame + std: DataFrame # std DataFrame for categorical columsn on train DataFrame + codes: Dict # codes for numerical columns used for training + target_codes: Dict # target codes for target used for training + num_classes: int # number of classes used for training + is_regression: bool # whether the task was a is_regression + + +class TabularPreprocess(Preprocess): def __init__( self, - categorical_input: List, - numerical_input: List, - target: str, + cat_cols: List[str], + num_cols: List[str], + target_col: str, mean: DataFrame, std: DataFrame, codes: Dict, + target_codes: Dict, + num_classes: int, + is_regression: bool = False, ): - self._categorical_input = categorical_input - self._numerical_input = numerical_input - self._target = target - self._mean = mean - self._std = std - self._codes = codes - - def before_collate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - if _contains_any_tensor(samples, dtype=(Tensor, np.ndarray)): - return samples - if isinstance(samples, str): - samples = pd.read_csv(samples) - if isinstance(samples, DataFrame): - samples = [samples] - dfs = _pre_transform( - samples, self._numerical_input, self._categorical_input, self._codes, self._mean, self._std - ) - return _dfs_to_samples(dfs, self._categorical_input, self._numerical_input) - + super().__init__() + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + 
self.num_classes = num_classes + self.is_regression = is_regression -class TabularData(DataModule): - """Data module for tabular tasks""" + @property + def state(self) -> TabularState: + return TabularState( + self.cat_cols, self.num_cols, self.target_col, self.mean, self.std, self.codes, self.target_codes, + self.num_classes, self.is_regression + ) - def __init__( - self, + @staticmethod + def generate_state( train_df: DataFrame, - target: str, - categorical_input: Optional[List] = None, - numerical_input: Optional[List] = None, - valid_df: Optional[DataFrame] = None, - test_df: Optional[DataFrame] = None, - batch_size: int = 2, - num_workers: Optional[int] = None, + valid_df: Optional[DataFrame], + test_df: Optional[DataFrame], + predict_df: Optional[DataFrame], + target_col: str, + num_cols: List[str], + cat_cols: List[str], + is_regression: bool, + preprocess_state: Optional[TabularState] = None ): - dfs = [train_df] - self._test_df = None + if preprocess_state is not None: + return preprocess_state - if categorical_input is None and numerical_input is None: - raise RuntimeError('Both `categorical_input` and `numerical_input` are None!') + if train_df is None: + raise MisconfigurationException("train_df is required to compute the preprocess state") - categorical_input = categorical_input if categorical_input else [] - numerical_input = numerical_input if numerical_input else [] + dfs = [train_df] if valid_df is not None: - dfs.append(valid_df) + dfs += [valid_df] if test_df is not None: - # save for predict function - self._test_df = test_df.copy() - self._test_df.drop(target, axis=1) - dfs.append(test_df) + dfs += [test_df] + + if predict_df is not None: + dfs += [predict_df] - # impute missing values - dfs = _impute(dfs, numerical_input) + mean, std = _compute_normalization(dfs[0], num_cols) + num_classes = len(dfs[0][target_col].unique()) + if dfs[0][target_col].dtype == object: + # if the target_col is a category, not an int + target_codes = _generate_codes(dfs, [target_col]) + else: + target_codes = None + codes = _generate_codes(dfs, cat_cols) + + return TabularState( + cat_cols, + num_cols, + target_col, + mean, + std, + codes, + target_codes, + num_classes, + is_regression, + ) + def common_load_data(self, df: DataFrame, dataset: AutoDataset): + # impute_data # compute train dataset stats - self.mean, self.std = _compute_normalization(dfs[0], numerical_input) + dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, + self.target_codes) - if dfs[0][target].dtype == object: - # if the target is a category, not an int - self.target_codes = _generate_codes(dfs, [target]) - else: - self.target_codes = None + df = dfs[0] - self.codes = _generate_codes(dfs, categorical_input) + dataset.num_samples = len(df) + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) + num_vars = _to_num_vars_numpy(df, self.num_cols) - dfs = _pre_transform( - dfs, numerical_input, categorical_input, self.codes, self.mean, self.std, target, self.target_codes - ) + cat_vars = np.stack(cat_vars, 1) if len(cat_vars) else np.zeros((len(self), 0)) + num_vars = np.stack(num_vars, 1) if len(num_vars) else np.zeros((len(self), 0)) + return df, cat_vars, num_vars + + def load_data(self, df: DataFrame, dataset: AutoDataset): + df, cat_vars, num_vars = self.common_load_data(df, dataset) + target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) + return [((c, n), t) for c, n, t in zip(cat_vars, num_vars, target)] - # normalize - 
self.cat_cols = categorical_input - self.num_cols = numerical_input + def predict_load_data(self, sample: Union[str, DataFrame], dataset: AutoDataset): + df = pd.read_csv(sample) if isinstance(sample, str) else sample + _, cat_vars, num_vars = self.common_load_data(df, dataset) + return list(zip(cat_vars, num_vars)) - self._num_classes = len(train_df[target].unique()) - train_ds = PandasDataset(dfs[0], categorical_input, numerical_input, target) - valid_ds = PandasDataset(dfs[1], categorical_input, numerical_input, target) if valid_df is not None else None - test_ds = PandasDataset(dfs[-1], categorical_input, numerical_input, target) if test_df is not None else None - super().__init__(train_ds, valid_ds, test_ds, batch_size=batch_size, num_workers=num_workers) +class TabularData(DataModule): + """Data module for tabular tasks""" + + preprocess_cls = TabularPreprocess + + @property + def preprocess_state(self) -> PreprocessState: + return self._preprocess.state + + @preprocess_state.setter + def preprocess_state(self, preprocess_state): + self._preprocess = self.preprocess_cls.from_state(preprocess_state) + + @property + def codes(self) -> Dict[str, str]: + return self.preprocess_state.codes @property def num_classes(self) -> int: - return self._num_classes + return self.preprocess_state.num_classes + + @property + def cat_cols(self) -> Optional[List[str]]: + return self.preprocess_state.cat_cols + + @property + def num_cols(self) -> Optional[List[str]]: + return self.preprocess_state.num_cols @property def num_features(self) -> int: return len(self.cat_cols) + len(self.num_cols) @classmethod - def from_df( + def from_csv( cls, - train_df: DataFrame, - target: str, - categorical_input: Optional[List] = None, - numerical_input: Optional[List] = None, - valid_df: Optional[DataFrame] = None, - test_df: Optional[DataFrame] = None, + target_col: str, + train_csv: Optional[str] = None, + categorical_cols: Optional[List] = None, + numerical_cols: Optional[List] = None, + valid_csv: Optional[str] = None, + test_csv: Optional[str] = None, + predict_csv: Optional[str] = None, batch_size: int = 8, num_workers: Optional[int] = None, - val_size: float = None, - test_size: float = None, + val_size: Optional[float] = None, + test_size: Optional[float] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, + preprocess_state: Optional[TabularState] = None, + **pandas_kwargs, ): - """Creates a TabularData object from pandas DataFrames. + """Creates a TextClassificationData object from pandas DataFrames. Args: - train_df: train data DataFrame - target: The column containing the class id. - categorical_input: The list of categorical columns. - numerical_input: The list of numerical columns. - valid_df: validation data DataFrame - test_df: test data DataFrame - batch_size: the batchsize to use for parallel loading. Defaults to 64. + train_csv: Train data csv file. + target_col: The column containing the class id. + categorical_cols: The list of categorical columns. + numerical_cols: The list of numerical columns. + valid_csv: Validation data csv file. + test_csv: Test data csv file. + batch_size: The batchsize to use for parallel loading. Defaults to 64. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. 
- val_size: float between 0 and 1 to create a validation dataset from train dataset - test_size: float between 0 and 1 to create a test dataset from train validation + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. + val_size: Float between 0 and 1 to create a validation dataset from train dataset. + test_size: Float between 0 and 1 to create a test dataset from train validation. + preprocess_cls: Preprocess class to be used within this DataModule DataPipeline. + preprocess_state: Used to store the train statistics. Returns: TabularData: The constructed data module. Examples:: - text_data = TextClassificationData.from_files("train.csv", label_field="class", text_field="sentence") + text_data = TabularData.from_files("train.csv", label_field="class", text_field="sentence") """ + train_df = pd.read_csv(train_csv, **pandas_kwargs) + valid_df = pd.read_csv(valid_csv, **pandas_kwargs) if valid_csv else None + test_df = pd.read_csv(test_csv, **pandas_kwargs) if test_csv else None + predict_df = pd.read_csv(predict_csv, **pandas_kwargs) if predict_csv else None + + return cls.from_df( + train_df, + target_col, + categorical_cols, + numerical_cols, + valid_df, + test_df, + predict_df, + batch_size, + num_workers, + val_size, + test_size, + preprocess_state=preprocess_state, + preprocess_cls=preprocess_cls, + ) + + @property + def emb_sizes(self) -> list: + """Recommended embedding sizes.""" + + # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html + # The following "formula" provides a general rule of thumb about the number of embedding dimensions: + # embedding_dimensions = number_of_categories**0.25 + num_classes = [len(self.codes[cat]) for cat in self.cat_cols] + emb_dims = [max(int(n**0.25), 16) for n in num_classes] + return list(zip(num_classes, emb_dims)) + + @staticmethod + def _split_dataframe( + train_df: DataFrame, + valid_df: Optional[DataFrame] = None, + test_df: Optional[DataFrame] = None, + val_size: float = None, + test_size: float = None, + ): if valid_df is None and isinstance(val_size, float) and isinstance(test_size, float): assert 0 < val_size < 1 assert 0 < test_size < 1 @@ -179,81 +282,81 @@ def from_df( assert 0 < test_size < 1 valid_df, test_df = train_test_split(valid_df, test_size=test_size) - datamodule = cls( - train_df=train_df, - target=target, - categorical_input=categorical_input, - numerical_input=numerical_input, - valid_df=valid_df, - test_df=test_df, - batch_size=batch_size, - num_workers=num_workers, - ) - datamodule.data_pipeline = TabularDataPipeline( - categorical_input, numerical_input, target, datamodule.mean, datamodule.std, datamodule.codes - ) + return train_df, valid_df, test_df + + @staticmethod + def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): + if cat_cols is None and num_cols is None: + raise RuntimeError('Both `cat_cols` and `num_cols` are None!') - return datamodule + return cat_cols or [], num_cols or [] @classmethod - def from_csv( + def from_df( cls, - train_csv: str, - target: str, - categorical_input: Optional[List] = None, - numerical_input: Optional[List] = None, - valid_csv: Optional[str] = None, - test_csv: Optional[str] = None, + train_df: DataFrame, + target_col: str, + categorical_cols: Optional[List] = None, + numerical_cols: Optional[List] = None, + valid_df: Optional[DataFrame] = None, + test_df: Optional[DataFrame] = None, + predict_df: Optional[DataFrame] = None, batch_size: int = 8, num_workers: Optional[int] = None, - 
val_size: Optional[float] = None, - test_size: Optional[float] = None, - **pandas_kwargs, + val_size: float = None, + test_size: float = None, + is_regression: bool = False, + preprocess_state: Optional[TabularState] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, ): - """Creates a TextClassificationData object from pandas DataFrames. + """Creates a TabularData object from pandas DataFrames. Args: - train_csv: train data csv file. - target: The column containing the class id. - categorical_input: The list of categorical columns. - numerical_input: The list of numerical columns. - valid_csv: validation data csv file. - test_csv: test data csv file. - batch_size: the batchsize to use for parallel loading. Defaults to 64. + train_df: Train data DataFrame. + target_col: The column containing the class id. + categorical_cols: The list of categorical columns. + numerical_cols: The list of numerical columns. + valid_df: Validation data DataFrame. + test_df: Test data DataFrame. + batch_size: The batchsize to use for parallel loading. Defaults to 64. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. - val_size: float between 0 and 1 to create a validation dataset from train dataset - test_size: float between 0 and 1 to create a test dataset from train validation + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. + val_size: Float between 0 and 1 to create a validation dataset from train dataset. + test_size: Float between 0 and 1 to create a test dataset from train validation. Returns: TabularData: The constructed data module. Examples:: - text_data = TabularData.from_files("train.csv", label_field="class", text_field="sentence") + text_data = TextClassificationData.from_files("train.csv", label_field="class", text_field="sentence") """ - train_df = pd.read_csv(train_csv, **pandas_kwargs) - valid_df = pd.read_csv(valid_csv, **pandas_kwargs) if valid_csv else None - test_df = pd.read_csv(test_csv, **pandas_kwargs) if test_csv else None - datamodule = cls.from_df( - train_df, target, categorical_input, numerical_input, valid_df, test_df, batch_size, num_workers, val_size, - test_size + categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols) + + train_df, valid_df, test_df = cls._split_dataframe(train_df, valid_df, test_df, val_size, test_size) + + preprocess_cls = preprocess_cls or cls.preprocess_cls + + preprocess_state = preprocess_cls.generate_state( + train_df, + valid_df, + test_df, + predict_df, + target_col, + numerical_cols, + categorical_cols, + is_regression, + preprocess_state=preprocess_state ) - return datamodule - - @property - def emb_sizes(self) -> list: - """Recommended embedding sizes.""" - - # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html - # The following "formula" provides a general rule of thumb about the number of embedding dimensions: - # embedding_dimensions = number_of_categories**0.25 - - num_classes = [len(self.codes[cat]) for cat in self.cat_cols] - emb_dims = [max(int(n**0.25), 16) for n in num_classes] - return list(zip(num_classes, emb_dims)) + preprocess = preprocess_cls.from_state(preprocess_state) - @staticmethod - def default_pipeline() -> DataPipeline(): - # TabularDataPipeline depends on the data - return DataPipeline() + return cls.from_load_data_inputs( + train_load_data_input=train_df, + valid_load_data_input=valid_df, + 
test_load_data_input=test_df, + predict_load_data_input=predict_df, + batch_size=batch_size, + num_workers=num_workers, + preprocess=preprocess + ) diff --git a/flash/tabular/classification/data/dataset.py b/flash/tabular/classification/data/dataset.py index 415345a048..670816856e 100644 --- a/flash/tabular/classification/data/dataset.py +++ b/flash/tabular/classification/data/dataset.py @@ -20,7 +20,7 @@ from sklearn.model_selection import train_test_split from torch.utils.data import Dataset -from flash.core.data import download_data +from flash.data.utils import download_data def _impute(dfs: List, num_cols: List) -> list: @@ -84,8 +84,8 @@ def _categorize(dfs: List, cat_cols: List, codes: Dict = None) -> list: def _pre_transform( dfs: List, - num_cols: List, - cat_cols: List, + num_cols: List[str], + cat_cols: List[str], codes: Dict, mean: DataFrame, std: DataFrame, @@ -100,32 +100,32 @@ def _pre_transform( return dfs -def _to_cat_vars_numpy(df, cat_cols) -> list: +def _to_cat_vars_numpy(df, cat_cols: List[str]) -> list: if isinstance(df, list) and len(df) == 1: df = df[0] return [c.to_numpy().astype(np.int64) for n, c in df[cat_cols].items()] -def _to_num_cols_numpy(df, num_cols) -> list: +def _to_num_vars_numpy(df, num_cols: List[str]) -> list: if isinstance(df, list) and len(df) == 1: df = df[0] return [c.to_numpy().astype(np.float32) for n, c in df[num_cols].items()] -def _dfs_to_samples(dfs, cat_cols, num_cols) -> list: +def _dfs_to_samples(dfs, cat_cols: List[str], num_cols: List[str]) -> list: num_samples = sum([len(df) for df in dfs]) cat_vars_list = [] num_vars_list = [] for df in dfs: cat_vars = _to_cat_vars_numpy(df, cat_cols) - num_vars = _to_num_cols_numpy(df, num_cols) + num_vars = _to_num_vars_numpy(df, num_cols) cat_vars_list.append(cat_vars) cat_vars_list.append(num_vars_list) # todo: assumes that dfs is not empty cat_vars = np.stack(cat_vars, 1) if len(cat_vars) else np.zeros((num_samples, 0)) num_vars = np.stack(num_vars, 1) if len(num_vars) else np.zeros((num_samples, 0)) - return [(c, n) for c, n in zip(cat_vars, num_vars)] + return list(zip(cat_vars, num_vars)) class PandasDataset(Dataset): @@ -133,19 +133,19 @@ class PandasDataset(Dataset): def __init__( self, df: DataFrame, - cat_cols: List, - num_cols: List, + cat_cols: List[str], + num_cols: List[str], target_col: str, - regression: bool = False, + is_regression: bool = False, predict: bool = False ): self._num_samples = len(df) self.predict = predict cat_vars = _to_cat_vars_numpy(df, cat_cols) - num_vars = _to_num_cols_numpy(df, num_cols) + num_vars = _to_num_vars_numpy(df, num_cols) if not predict: - self.target = df[target_col].to_numpy().astype(np.float32 if regression else np.int64) + self.target = df[target_col].to_numpy().astype(np.float32 if is_regression else np.int64) self.cat_vars = np.stack(cat_vars, 1) if len(cat_vars) else np.zeros((len(self), 0)) self.num_vars = np.stack(num_vars, 1) if len(num_vars) else np.zeros((len(self), 0)) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 7c9738b328..cc9b76e431 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -11,16 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
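# A minimal, self-contained sketch of the ``emb_sizes`` heuristic used by
# ``TabularData`` above: each categorical column gets an embedding whose width
# follows the "number_of_categories ** 0.25" rule of thumb, floored at 16 by
# ``max(int(n ** 0.25), 16)``. The column names and cardinalities below are
# illustrative only.
cardinalities = {"Sex": 2, "Embarked": 4, "Ticket": 681}

emb_sizes = [(n, max(int(n ** 0.25), 16)) for n in cardinalities.values()]
print(emb_sizes)  # [(2, 16), (4, 16), (681, 16)] -> (num_embeddings, embedding_dim) pairs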
-from typing import Any, Callable, List, Optional, Tuple, Type +from typing import Any, Callable, List, Tuple, Type import torch -from pytorch_tabnet.tab_network import TabNet from torch.nn import functional as F -from torch.nn.functional import softmax from torchmetrics import Metric from flash.core.classification import ClassificationTask -from flash.core.data import DataPipeline +from flash.utils.imports import _TABNET_AVAILABLE + +if _TABNET_AVAILABLE: + from pytorch_tabnet.tab_network import TabNet class TabularClassifier(ClassificationTask): @@ -69,31 +70,15 @@ def __init__( learning_rate=learning_rate, ) - def predict( - self, - x: Any, - batch_idx: Optional[int] = None, - skip_collate_fn: bool = False, - dataloader_idx: Optional[int] = None, - data_pipeline: Optional[DataPipeline] = None, - ) -> Any: - # override parent predict because forward is called here with the whole batch - data_pipeline = data_pipeline or self.data_pipeline - batch = x if skip_collate_fn else data_pipeline.collate_fn(x) - predictions = self.forward(batch) - return data_pipeline.uncollate_fn(predictions) - - def forward(self, x_in): + def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) x = torch.cat([x for x in x_in if x.numel()], dim=1) - return softmax(self.model(x)[0]) + return F.softmax(self.model(x)[0], -1) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) @classmethod def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier': model = cls(datamodule.num_features, datamodule.num_classes, datamodule.emb_sizes, **kwargs) return model - - @staticmethod - def default_pipeline() -> DataPipeline: - # TabularDataPipeline depends on the data - return DataPipeline() diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index ef9acd81a3..4b24f82424 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -11,121 +11,95 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
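# A hedged end-to-end sketch of the refactored tabular API above: ``from_csv``
# now takes the target column first and builds a ``TabularPreprocess`` behind
# the scenes, and ``TabularClassifier.from_data`` derives its sizes from the
# datamodule. The csv path and the "fraud" / "account_*" column names are
# illustrative only.
from flash.tabular import TabularClassifier, TabularData

datamodule = TabularData.from_csv(
    "fraud",                                # target_col
    train_csv="train_data.csv",             # assumed file
    categorical_cols=["account_type"],
    numerical_cols=["account_value"],
    val_size=0.2,
    test_size=0.2,
)
model = TabularClassifier.from_data(datamodule)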
+import os +from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union -import torch -from datasets import load_dataset -from datasets.utils.download_manager import GenerateMode +from datasets import DatasetDict, load_dataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from transformers import AutoTokenizer, default_data_collator from transformers.modeling_outputs import SequenceClassifierOutput -from flash.core.classification import ClassificationDataPipeline -from flash.core.data import DataModule -from flash.core.data.utils import _contains_any_tensor - - -def tokenize_text_lambda(tokenizer, input, max_length): - return lambda ex: tokenizer( - ex[input], - max_length=max_length, - truncation=True, - padding="max_length", - ) - - -def prepare_dataset( - tokenizer, - train_file, - valid_file, - test_file, - filetype, - input, - max_length, - target=None, - label_to_class_mapping=None, - predict=False, -): - data_files = {} - - if train_file: - data_files["train"] = train_file - if valid_file: - data_files["validation"] = valid_file - if test_file: - data_files["test"] = test_file - - dataset_dict = load_dataset(filetype, data_files=data_files, download_mode=GenerateMode.FORCE_REDOWNLOAD) - - if not predict: - if label_to_class_mapping is None: - label_to_class_mapping = { - v: k - for k, v in enumerate(list(sorted(list(set(dataset_dict["train"][target]))))) - } - - def transform_label(ex): - ex[target] = label_to_class_mapping[ex[target]] - return ex - - # convert labels to ids +from flash.core.classification import ClassificationPostprocess +from flash.data.auto_dataset import AutoDataset +from flash.data.data_module import DataModule +from flash.data.process import Preprocess, PreprocessState - dataset_dict = dataset_dict.map(transform_label) - # tokenize text field - dataset_dict = dataset_dict.map( - tokenize_text_lambda(tokenizer, input, max_length), - batched=True, - ) +@dataclass(unsafe_hash=True, frozen=True) +class TextClassificationState(PreprocessState): + label_to_class_mapping: Dict[str, int] - if target != "labels" and not predict: - dataset_dict.rename_column_(target, "labels") - dataset_dict.set_format("torch", columns=["input_ids"] if predict else ["input_ids", "labels"]) - train_ds = None - valid_ds = None - test_ds = None +class TextClassificationPreprocess(Preprocess): - if "train" in dataset_dict: - train_ds = dataset_dict["train"] - - if "validation" in dataset_dict: - valid_ds = dataset_dict["validation"] + def __init__( + self, + tokenizer: AutoTokenizer, + input: str, + max_length: int, + target: str, + filetype: str, + label_to_class_mapping: Dict[str, int], + ): + """ + This class contains the preprocessing logic for text classification - if "test" in dataset_dict: - test_ds = dataset_dict["test"] + Args: + tokenizer: Hugging Face Tokenizer. + input: The field storing the text to be classified. + max_length: Maximum number of tokens within a single sentence. + target: The field storing the class id of the associated text. + filetype: .csv or .json format type. + label_to_class_mapping: Dictionnary mapping target labels to class indexes. - return train_ds, valid_ds, test_ds, label_to_class_mapping + Returns: + TextClassificationPreprocess: The constructed preprocess objects. 
+ """ -class TextClassificationDataPipeline(ClassificationDataPipeline): + super().__init__() + self.tokenizer = tokenizer + self.input = input + self.filetype = filetype + self.max_length = max_length + self.label_to_class_mapping = label_to_class_mapping + self.target = target - def __init__(self, tokenizer, input: str, max_length: int): - self._tokenizer = tokenizer - self._input = input - self._max_length = max_length self._tokenize_fn = partial( - self._tokenize_fn, tokenizer=self._tokenizer, input=self._input, max_length=self._max_length - ) - - @staticmethod - def _tokenize_fn(ex, tokenizer=None, input: str = None, max_length: int = None) -> Callable: - return tokenizer( - ex[input], - max_length=max_length, + self._tokenize_fn, + tokenizer=self.tokenizer, + input=self.input, + max_length=self.max_length, truncation=True, - padding="max_length", + padding="max_length" ) - def before_collate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - if _contains_any_tensor(samples): - return samples - elif isinstance(samples, (list, tuple)) and len(samples) > 0 and all(isinstance(s, str) for s in samples): - return [self._tokenize_fn({self._input: s}) for s in samples] - raise MisconfigurationException("samples can only be tensors or a list of sentences.") + @property + def state(self): + return TextClassificationState(self.label_to_class_mapping) + + def per_batch_transform(self, batch: Any) -> Any: + if "labels" not in batch: + # todo: understand why an extra dimension has been added. + if batch["input_ids"].dim() == 3: + batch["input_ids"] = batch["input_ids"].squeeze(0) + return batch + + @staticmethod + def _tokenize_fn( + ex: Union[Dict[str, str], str], + tokenizer=None, + input: str = None, + max_length: int = None, + **kwargs + ) -> Callable: + """This function is used to tokenize sentences using the provided tokenizer.""" + if isinstance(ex, dict): + ex = ex[input] + return tokenizer(ex, max_length=max_length, **kwargs) def collate(self, samples: Any) -> Tensor: """Override to convert a set of samples to a batch""" @@ -133,44 +107,146 @@ def collate(self, samples: Any) -> Tensor: samples = [samples] return default_data_collator(samples) - def after_collate(self, batch: Tensor) -> Tensor: - if "labels" not in batch: - # todo: understand why an extra dimension has been added. - if batch["input_ids"].dim() == 3: - batch["input_ids"] = batch["input_ids"].squeeze(0) - return batch + def _transform_label(self, ex: Dict[str, str]): + ex[self.target] = self.label_to_class_mapping[ex[self.target]] + return ex + + @staticmethod + def generate_state(file: str, target: str, filetype: str) -> TextClassificationState: + data_files = {'train': file} + dataset_dict = load_dataset(filetype, data_files=data_files) + label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(dataset_dict['train'][target])))))} + return TextClassificationState(label_to_class_mapping) + + def load_data( + self, + filepath: str, + dataset: AutoDataset, + columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), + use_full: bool = True + ): + data_files = {} + + stage = dataset.running_stage.value + data_files[stage] = str(filepath) + + # FLASH_TESTING is set in the CI to run faster. + if use_full and os.getenv("FLASH_TESTING", "0") == "0": + dataset_dict = load_dataset(self.filetype, data_files=data_files) + else: + # used for debugging. 
Avoid processing the entire dataset # noqa E265 + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] + }) + + dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + + # convert labels to ids + if not self.predicting: + dataset_dict = dataset_dict.map(self._transform_label) + + dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) - def before_uncollate(self, batch: Union[Tensor, tuple, SequenceClassifierOutput]) -> Union[tuple, Tensor]: + # Hugging Face models expect target to be named ``labels``. + if not self.predicting and self.target != "labels": + dataset_dict.rename_column_(self.target, "labels") + + dataset_dict.set_format("torch", columns=columns) + + if not self.predicting: + dataset.num_classes = len(self.label_to_class_mapping) + + return dataset_dict[stage] + + def predict_load_data(self, sample: Any, dataset: AutoDataset): + if isinstance(sample, str) and os.path.isfile(sample) and sample.endswith(".csv"): + return self.load_data(sample, dataset, columns=["input_ids", "attention_mask"]) + else: + if isinstance(sample, str): + sample = [sample] + + if isinstance(sample, list) and all(isinstance(s, str) for s in sample): + return [self._tokenize_fn(s) for s in sample] + + else: + raise MisconfigurationException("Currently, we support only list of sentences") + + +class TextClassificationPostProcess(ClassificationPostprocess): + + def per_batch_transform(self, batch: Any) -> Any: if isinstance(batch, SequenceClassifierOutput): batch = batch.logits - return super().before_uncollate(batch) + return super().per_batch_transform(batch) class TextClassificationData(DataModule): - """Data module for text classification tasks.""" + """Data Module for text classification tasks""" + preprocess_cls = TextClassificationPreprocess + postprocess_cls = TextClassificationPostProcess + target: Optional[str] = None - @staticmethod - def default_pipeline(): - return TextClassificationDataPipeline( - AutoTokenizer.from_pretrained("prajjwal1/bert-tiny", use_fast=True), - "sentiment", # Todo: find a way to get the target column name or impose target - 128, + @property + def preprocess_state(self) -> TextClassificationState: + return self._preprocess.state + + @property + def num_classes(self) -> int: + return len(self.preprocess_state.label_to_class_mapping) + + @classmethod + def instantiate_preprocess( + cls, + train_file: Optional[str], + input: str, + target: str, + filetype: str, + backbone: str, + max_length: int, + label_to_class_mapping: Optional[dict] = None, + preprocess_state: Optional[TextClassificationState] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, + ): + if label_to_class_mapping is None: + preprocess_cls = preprocess_cls or cls.preprocess_cls + if train_file is not None: + preprocess_state = preprocess_cls.generate_state(train_file, target, filetype) + else: + if preprocess_state is None: + raise MisconfigurationException( + "Either ``preprocess_state`` or ``train_file`` needs to be provided" + ) + label_to_class_mapping = preprocess_state.label_to_class_mapping + + preprocess_cls = preprocess_cls or cls.preprocess_cls + + return preprocess_cls( + AutoTokenizer.from_pretrained(backbone, use_fast=True), + input, + max_length, + target, + filetype, + label_to_class_mapping, ) @classmethod def from_files( cls, - train_file, - input, - target, + train_file: Optional[str], + input: Optional[str] = 'input', + target: Optional[str] = 'labels', filetype: str = "csv", backbone: str = 
"prajjwal1/bert-tiny", - valid_file: str = None, - test_file: str = None, + valid_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, max_length: int = 128, + label_to_class_mapping: Optional[dict] = None, batch_size: int = 16, num_workers: Optional[int] = None, - ): + preprocess_state: Optional[TextClassificationState] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, + ) -> 'TextClassificationData': """Creates a TextClassificationData object from files. Args: @@ -178,12 +254,13 @@ def from_files( input: The field storing the text to be classified. target: The field storing the class id of the associated text. filetype: .csv or .json - backbone: tokenizer to use, can use any HuggingFace tokenizer. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. valid_file: Path to validation data. test_file: Path to test data. batch_size: the batchsize to use for parallel loading. Defaults to 64. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. Returns: TextClassificationData: The constructed data module. @@ -192,36 +269,32 @@ def from_files( train_df = pd.read_csv("train_data.csv") tab_data = TabularData.from_df(train_df, target="fraud", - numerical_input=["account_value"], - categorical_input=["account_type"]) + num_cols=["account_value"], + cat_cols=["account_type"]) """ - tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - - train_ds, valid_ds, test_ds, label_to_class_mapping = prepare_dataset( - tokenizer=tokenizer, - train_file=train_file, - valid_file=valid_file, - test_file=test_file, - filetype=filetype, - input=input, - max_length=max_length, - target=target, - label_to_class_mapping=None + preprocess = cls.instantiate_preprocess( + train_file, + input, + target, + filetype, + backbone, + max_length, + label_to_class_mapping, + preprocess_state, + preprocess_cls, ) - datamodule = cls( - train_ds=train_ds, - valid_ds=valid_ds, - test_ds=test_ds, + return cls.from_load_data_inputs( + train_load_data_input=train_file, + valid_load_data_input=valid_file, + test_load_data_input=test_file, + predict_load_data_input=predict_file, batch_size=batch_size, num_workers=num_workers, + preprocess=preprocess ) - datamodule.num_classes = len(label_to_class_mapping) - datamodule.data_pipeline = TextClassificationDataPipeline(tokenizer, input=input, max_length=max_length) - return datamodule - @classmethod def from_file( cls, @@ -230,44 +303,36 @@ def from_file( backbone="bert-base-cased", filetype="csv", max_length: int = 128, + preprocess_state: Optional[TextClassificationState] = None, + label_to_class_mapping: Optional[dict] = None, batch_size: int = 16, num_workers: Optional[int] = None, - ): + ) -> 'TextClassificationData': """Creates a TextClassificationData object from files. Args: - predict_file: Path to prediction data. + + predict_file: Path to training data. input: The field storing the text to be classified. filetype: .csv or .json - backbone: tokenizer to use, can use any HuggingFace tokenizer. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. batch_size: the batchsize to use for parallel loading. Defaults to 64. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. 
- - Returns: - TextClassificationData: The constructed data module. - + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. """ - tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - - _, _, predict_ds, _ = prepare_dataset( - tokenizer=tokenizer, - train_file=None, - valid_file=None, - test_file=predict_file, - filetype=filetype, + return cls.from_files( + None, input=input, + target=None, + filetype=filetype, + backbone=backbone, + valid_file=None, + test_file=None, + predict_file=predict_file, max_length=max_length, - predict=True, - ) - - datamodule = cls( - train_ds=None, - valid_ds=None, - test_ds=predict_ds, + label_to_class_mapping=label_to_class_mapping, batch_size=batch_size, num_workers=num_workers, + preprocess_state=preprocess_state, ) - - datamodule.data_pipeline = TextClassificationDataPipeline(tokenizer, input=input, max_length=max_length) - return datamodule diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 75d85eef9b..40c6843f35 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -18,8 +18,9 @@ import torch from torchmetrics import Accuracy from transformers import BertForSequenceClassification +from transformers.modeling_outputs import SequenceClassifierOutput -from flash.core.classification import ClassificationDataPipeline, ClassificationTask +from flash.core.classification import ClassificationTask from flash.text.classification.data import TextClassificationData @@ -73,10 +74,8 @@ def step(self, batch, batch_idx) -> dict: loss, logits = out[:2] output["loss"] = loss output["y_hat"] = logits - probs = self.data_pipeline.before_uncollate(logits) + if isinstance(logits, SequenceClassifierOutput): + logits = logits.logits + probs = torch.softmax(logits, 1) output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()} return output - - @staticmethod - def default_pipeline() -> ClassificationDataPipeline: - return TextClassificationData.default_pipeline() diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index d539c5ef43..2d9b9e98d6 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -11,96 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
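# A hedged usage sketch for the refactored ``TextClassificationData.from_files``
# defined above: the tokenizer, label mapping and csv loading now live in
# ``TextClassificationPreprocess``. The file path and the "review" / "sentiment"
# column names are assumptions for illustration.
from flash.text import TextClassificationData

datamodule = TextClassificationData.from_files(
    train_file="data/reviews/train.csv",   # assumed path
    input="review",                        # column holding the text
    target="sentiment",                    # column holding the label
    backbone="prajjwal1/bert-tiny",
    max_length=128,
    batch_size=16,
)
print(datamodule.num_classes)              # derived from the generated preprocess state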
+import os from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union -from datasets import load_dataset +import datasets +import torch +from datasets import DatasetDict, load_dataset +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from transformers import AutoTokenizer, default_data_collator -from flash.core.data import DataModule, TaskDataPipeline +from flash.data.data_module import DataModule +from flash.data.process import Preprocess -def prepare_dataset( - test_file: str, - filetype: str, - pipeline: TaskDataPipeline, - train_file: Optional[str] = None, - valid_file: Optional[str] = None, - predict: bool = False -): - data_files = {} - - if train_file: - data_files["train"] = train_file - if valid_file: - data_files["validation"] = valid_file - if test_file: - data_files["test"] = test_file - - # load the dataset - dataset_dict = load_dataset( - filetype, - data_files=data_files, - ) - - # tokenize the dataset - dataset_dict = dataset_dict.map( - pipeline._tokenize_fn, - batched=True, - ) - columns = ["input_ids", "attention_mask"] if predict else ["input_ids", "attention_mask", "labels"] - dataset_dict.set_format(columns=columns) - - train_ds = None - valid_ds = None - test_ds = None - - if "train" in dataset_dict: - train_ds = dataset_dict["train"] - - if "validation" in dataset_dict: - valid_ds = dataset_dict["validation"] - - if "test" in dataset_dict: - test_ds = dataset_dict["test"] - - return train_ds, valid_ds, test_ds - - -class Seq2SeqDataPipeline(TaskDataPipeline): +class Seq2SeqPreprocess(Preprocess): def __init__( self, tokenizer, input: str, + filetype: str, target: Optional[str] = None, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'longest' ): + super().__init__() + self.tokenizer = tokenizer - self._input = input - self._target = target - self._max_target_length = max_target_length - self._max_source_length = max_source_length - self._padding = padding + self.input = input + self.filetype = filetype + self.target = target + self.max_target_length = max_target_length + self.max_source_length = max_source_length + self.padding = padding self._tokenize_fn = partial( self._tokenize_fn, tokenizer=self.tokenizer, - input=self._input, - target=self._target, - max_source_length=self._max_source_length, - max_target_length=self._max_target_length, - padding=self._padding, + input=self.input, + target=self.target, + max_source_length=self.max_source_length, + max_target_length=self.max_target_length, + padding=self.padding ) - def before_collate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - if isinstance(samples, (list, tuple)) and len(samples) > 0 and all(isinstance(s, str) for s in samples): - return [self._tokenize_fn({self._input: s, self._target: None}) for s in samples] - return samples - @staticmethod def _tokenize_fn( ex, @@ -120,99 +76,155 @@ def _tokenize_fn( ) return output + def load_data( + self, + file: str, + use_full: bool = True, + columns: List[str] = ["input_ids", "attention_mask", "labels"] + ) -> 'datasets.Dataset': + data_files = {} + stage = self._running_stage.value + data_files[stage] = str(file) + + # FLASH_TESTING is set in the CI to run faster. + if use_full and os.getenv("FLASH_TESTING", "0") == "0": + dataset_dict = load_dataset(self.filetype, data_files=data_files) + else: + # used for debugging. 
Avoid processing the entire dataset # noqa E265 + try: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] + }) + except AssertionError: + dataset_dict = load_dataset(self.filetype, data_files=data_files) + + dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + dataset_dict.set_format(columns=columns) + return dataset_dict[stage] + + def predict_load_data(self, sample: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: + if isinstance(sample, str) and os.path.isfile(sample) and sample.endswith(".csv"): + return self.load_data(sample, use_full=True, columns=["input_ids", "attention_mask"]) + else: + if isinstance(sample, (list, tuple)) and len(sample) > 0 and all(isinstance(s, str) for s in sample): + return [self._tokenize_fn({self.input: s, self.target: None}) for s in sample] + else: + raise MisconfigurationException("Currently, we support only list of sentences") + def collate(self, samples: Any) -> Tensor: """Override to convert a set of samples to a batch""" return default_data_collator(samples) - def after_collate(self, batch: Any) -> Any: - return batch - - def uncollate(self, generated_tokens: Any) -> Any: - pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) - pred_str = [str.strip(s) for s in pred_str] - return pred_str - class Seq2SeqData(DataModule): """Data module for Seq2Seq tasks.""" - @staticmethod - def default_pipeline(): - return Seq2SeqDataPipeline( - AutoTokenizer.from_pretrained("sshleifer/tiny-mbart", use_fast=True), - input="input", + preprocess_cls = Seq2SeqPreprocess + + @classmethod + def instantiate_preprocess( + cls, + tokenizer: AutoTokenizer, + input: str, + filetype: str, + target: str, + max_source_length: int, + max_target_length: int, + padding: int, + preprocess_cls: Optional[Type[Preprocess]] = None + ) -> Preprocess: + """ + This function is used to instantiate the ``Seq2SeqPreprocess`` preprocess. + + Args: + tokenizer: Path to training data. + input: The field storing the source translation text. + filetype: ``csv`` or ``json`` File + target: The field storing the target translation text. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. + max_source_length: Maximum length of the source text. Any text longer will be truncated. + max_target_length: Maximum length of the target text. Any text longer will be truncated. + padding: Padding strategy for batches. Default is pad to maximum length. + preprocess_cls: Preprocess cls + """ + + preprocess_cls = preprocess_cls or cls.preprocess_cls + + return preprocess_cls( + tokenizer=tokenizer, + input=input, + filetype=filetype, + target=target, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, ) @classmethod def from_files( cls, - train_file: str, + train_file: Optional[str], input: str = 'input', target: Optional[str] = None, filetype: str = "csv", backbone: str = "sshleifer/tiny-mbart", valid_file: Optional[str] = None, test_file: Optional[str] = None, + predict_file: Optional[str] = None, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', batch_size: int = 32, num_workers: Optional[int] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, ): """Creates a Seq2SeqData object from files. - Args: train_file: Path to training data. input: The field storing the source translation text. target: The field storing the target translation text. 
- filetype: .csv or .json - backbone: tokenizer to use, can use any HuggingFace tokenizer. + filetype: ``csv`` or ``json`` File + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. valid_file: Path to validation data. test_file: Path to test data. max_source_length: Maximum length of the source text. Any text longer will be truncated. max_target_length: Maximum length of the target text. Any text longer will be truncated. padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: the batchsize to use for parallel loading. Defaults to 32. + batch_size: The batchsize to use for parallel loading. Defaults to 32. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. - + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. Returns: Seq2SeqData: The constructed data module. - Examples:: - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - numerical_input=["account_value"], - categorical_input=["account_type"]) - + tab_data = TabularData.from_df(train_df, + target="fraud", + num_cols=["account_value"], + cat_cols=["account_type"]) """ tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - - pipeline = Seq2SeqDataPipeline( - tokenizer=tokenizer, - input=input, - target=target, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding - ) - - train_ds, valid_ds, test_ds = prepare_dataset( - train_file=train_file, valid_file=valid_file, test_file=test_file, filetype=filetype, pipeline=pipeline + preprocess = cls.instantiate_preprocess( + tokenizer, + input, + filetype, + target, + max_source_length, + max_target_length, + padding, + preprocess_cls=preprocess_cls ) - datamodule = cls( - train_ds=train_ds, - valid_ds=valid_ds, - test_ds=test_ds, + return cls.from_load_data_inputs( + train_load_data_input=train_file, + valid_load_data_input=valid_file, + test_load_data_input=test_file, + predict_load_data_input=predict_file, batch_size=batch_size, num_workers=num_workers, + preprocess=preprocess ) - datamodule.data_pipeline = pipeline - return datamodule - @classmethod def from_file( cls, @@ -226,48 +238,36 @@ def from_file( padding: Union[str, bool] = 'max_length', batch_size: int = 32, num_workers: Optional[int] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, ): """Creates a TextClassificationData object from files. - Args: predict_file: Path to prediction input file. input: The field storing the source translation text. target: The field storing the target translation text. - backbone: tokenizer to use, can use any HuggingFace tokenizer. - filetype: csv or json. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. + filetype: Csv or json. max_source_length: Maximum length of the source text. Any text longer will be truncated. max_target_length: Maximum length of the target text. Any text longer will be truncated. padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: the batchsize to use for parallel loading. Defaults to 32. + batch_size: The batchsize to use for parallel loading. Defaults to 32. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. - + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. 
Returns: Seq2SeqData: The constructed data module. - """ - tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - - pipeline = Seq2SeqDataPipeline( - tokenizer=tokenizer, + return cls.from_files( + train_file=None, input=input, target=target, + filetype=filetype, + backbone=backbone, + predict_file=predict_file, max_source_length=max_source_length, max_target_length=max_target_length, - padding=padding - ) - - train_ds, valid_ds, test_ds = prepare_dataset( - test_file=predict_file, filetype=filetype, pipeline=pipeline, predict=True - ) - - datamodule = cls( - train_ds=train_ds, - valid_ds=valid_ds, - test_ds=test_ds, + padding=padding, batch_size=batch_size, num_workers=num_workers, + preprocess_cls=preprocess_cls, ) - - datamodule.data_pipeline = pipeline - return datamodule diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 97b0935173..8971584bde 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -90,8 +90,8 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor: self.log("train_loss", loss) return loss - def common_step(self, prefix: str, batch: Any) -> Tensor: - generated_tokens = self.predict(batch, skip_collate_fn=True) + def common_step(self, prefix: str, batch: Any) -> torch.Tensor: + generated_tokens = self(batch) self.compute_metrics(generated_tokens, batch, prefix) def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): @@ -108,7 +108,7 @@ def task(self) -> Optional[str]: """ Override to define AutoConfig task specific parameters stored within the model. """ - pass + return def _initialize_model_specific_parameters(self): task_specific_params = self.model.config.task_specific_params @@ -120,7 +120,7 @@ def _initialize_model_specific_parameters(self): @property def tokenizer(self) -> PreTrainedTokenizerBase: - return self.data_pipeline.tokenizer + return self.data_pipeline._preprocess_pipeline.tokenizer def tokenize_labels(self, labels: Tensor) -> List[str]: label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True) diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 20e0eb2ba2..ba9b93b6e0 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -11,36 +11,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
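# A minimal usage sketch for the ``Seq2SeqData.from_files`` API above, assuming a
# csv with an "input" and a "target" column; the file path is illustrative and the
# backbone matches the default named in the signature.
from flash.text.seq2seq.core.data import Seq2SeqData

datamodule = Seq2SeqData.from_files(
    train_file="data/translations/train.csv",  # assumed path
    input="input",
    target="target",
    backbone="sshleifer/tiny-mbart",
    max_source_length=128,
    max_target_length=128,
    batch_size=32,
)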
+from typing import Any, Optional, Type, Union + from transformers import AutoTokenizer -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqDataPipeline +from flash.data.process import Postprocess, Preprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess + + +class SummarizationPostprocess(Postprocess): + + def __init__( + self, + tokenizer: AutoTokenizer, + ): + super().__init__() + self.tokenizer = tokenizer + + def uncollate(self, generated_tokens: Any) -> Any: + pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + pred_str = [str.strip(s) for s in pred_str] + return pred_str class SummarizationData(Seq2SeqData): - from typing import Optional, Union - @staticmethod - def default_pipeline(): - return Seq2SeqDataPipeline( - AutoTokenizer.from_pretrained("t5-small", use_fast=True), - input="input", - ) + preprocess_cls = Seq2SeqPreprocess + postprocess_cls = SummarizationPostprocess + + @classmethod + def instantiate_postprocess( + cls, tokenizer: AutoTokenizer, postprocess_cls: Optional[Type[Postprocess]] = None + ) -> Postprocess: + postprocess_cls = postprocess_cls or cls.postprocess_cls + return postprocess_cls(tokenizer) @classmethod def from_files( cls, - train_file: str, + train_file: Optional[str] = None, input: str = 'input', target: Optional[str] = None, filetype: str = "csv", backbone: str = "t5-small", valid_file: str = None, test_file: str = None, + predict_file: str = None, max_source_length: int = 512, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', batch_size: int = 16, num_workers: Optional[int] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, + postprocess_cls: Optional[Type[Postprocess]] = None, ): """Creates a SummarizationData object from files. @@ -49,15 +72,16 @@ def from_files( input: The field storing the source translation text. target: The field storing the target translation text. filetype: .csv or .json - backbone: tokenizer to use, can use any HuggingFace tokenizer. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. valid_file: Path to validation data. test_file: Path to test data. max_source_length: Maximum length of the source text. Any text longer will be truncated. max_target_length: Maximum length of the target text. Any text longer will be truncated. padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: the batchsize to use for parallel loading. Defaults to 16. + batch_size: The batchsize to use for parallel loading. Defaults to 16. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. Returns: SummarizationData: The constructed data module. 
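# What ``SummarizationPostprocess.uncollate`` above boils down to: generated token
# ids are decoded back into stripped strings with the task tokenizer. A standalone
# sketch using the same t5-small tokenizer referenced in this file; the encoded
# sentence is only a placeholder for real generated tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=True)
generated_tokens = tokenizer(["summarize: the quick brown fox ..."], return_tensors="pt").input_ids
pred_str = [s.strip() for s in tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)]
print(pred_str)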
@@ -66,23 +90,33 @@ def from_files( train_df = pd.read_csv("train_data.csv") tab_data = TabularData.from_df(train_df, target="fraud", - numerical_input=["account_value"], - categorical_input=["account_type"]) + num_cols=["account_value"], + cat_cols=["account_type"]) """ - return super().from_files( - train_file=train_file, - valid_file=valid_file, - test_file=test_file, - input=input, - target=target, - backbone=backbone, - filetype=filetype, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, + tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + preprocess = cls.instantiate_preprocess( + tokenizer, + input, + filetype, + target, + max_source_length, + max_target_length, + padding, + preprocess_cls=preprocess_cls + ) + + postprocess = cls.instantiate_postprocess(tokenizer, postprocess_cls=postprocess_cls) + + return cls.from_load_data_inputs( + train_load_data_input=train_file, + valid_load_data_input=valid_file, + test_load_data_input=test_file, + predict_load_data_input=predict_file, batch_size=batch_size, - num_workers=num_workers + num_workers=num_workers, + preprocess=preprocess, + postprocess=postprocess, ) @classmethod @@ -105,14 +139,15 @@ def from_file( predict_file: Path to prediction input file. input: The field storing the source translation text. target: The field storing the target translation text. - backbone: tokenizer to use, can use any HuggingFace tokenizer. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. filetype: csv or json. max_source_length: Maximum length of the source text. Any text longer will be truncated. max_target_length: Maximum length of the target text. Any text longer will be truncated. padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: the batchsize to use for parallel loading. Defaults to 16. + batch_size: The batchsize to use for parallel loading. Defaults to 16. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. Returns: SummarizationData: The constructed data module. diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index a044423253..a3c9142bb5 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
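# A hedged sketch of the summarization-specific ``from_files`` above, which now
# wires a Seq2Seq preprocess and a decoding postprocess built from the same
# t5-small tokenizer into the datamodule. The path and column names are assumptions.
from flash.text.seq2seq.summarization.data import SummarizationData

datamodule = SummarizationData.from_files(
    train_file="data/articles/train.csv",  # assumed path
    input="input",
    target="target",
    backbone="t5-small",
    max_source_length=512,
    max_target_length=128,
    batch_size=16,
)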
-from typing import Callable, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch @@ -66,7 +66,7 @@ def __init__( def task(self) -> str: return "summarization" - def compute_metrics(self, generated_tokens, batch, prefix): + def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str) -> None: tgt_lns = self.tokenize_labels(batch["labels"]) - result = self.rouge(generated_tokens, tgt_lns) + result = self.rouge(self._postprocess.uncollate(generated_tokens), tgt_lns) self.log_dict(result, on_step=False, on_epoch=True) diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index afaf9b5cfb..92096b431a 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -11,23 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional, Type, Union -from transformers import AutoTokenizer - -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqDataPipeline +from flash.data.process import Preprocess +from flash.text.seq2seq.core.data import Seq2SeqData class TranslationData(Seq2SeqData): """Data module for Translation tasks.""" - @staticmethod - def default_pipeline(): - return Seq2SeqDataPipeline( - AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro", use_fast=True), - input="input", - ) - @classmethod def from_files( cls, @@ -38,11 +30,13 @@ def from_files( backbone="facebook/mbart-large-en-ro", valid_file=None, test_file=None, + predict_file=None, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', batch_size: int = 8, num_workers: Optional[int] = None, + preprocess_cls: Optional[Type[Preprocess]] = None ): """Creates a TranslateData object from files. @@ -51,15 +45,17 @@ def from_files( input: The field storing the source translation text. target: The field storing the target translation text. filetype: .csv or .json - backbone: tokenizer to use, can use any HuggingFace tokenizer. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. valid_file: Path to validation data. test_file: Path to test data. + predict_file: Path to predict data. max_source_length: Maximum length of the source text. Any text longer will be truncated. max_target_length: Maximum length of the target text. Any text longer will be truncated. padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: the batchsize to use for parallel loading. Defaults to 8. + batch_size: The batchsize to use for parallel loading. Defaults to 8. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. + Defaults to None which equals the number of available CPU threads, + or 0 for Darwin platform. Returns: TranslateData: The constructed data module. 
@@ -68,14 +64,15 @@ def from_files( train_df = pd.read_csv("train_data.csv") tab_data = TabularData.from_df(train_df, target="fraud", - numerical_input=["account_value"], - categorical_input=["account_type"]) + num_cols=["account_value"], + cat_cols=["account_type"]) """ return super().from_files( train_file=train_file, valid_file=valid_file, test_file=test_file, + predict_file=predict_file, input=input, target=target, backbone=backbone, @@ -84,7 +81,8 @@ def from_files( max_target_length=max_target_length, padding=padding, batch_size=batch_size, - num_workers=num_workers + num_workers=num_workers, + preprocess_cls=preprocess_cls ) @classmethod @@ -107,14 +105,15 @@ def from_file( predict_file: Path to prediction input file. input: The field storing the source translation text. target: The field storing the target translation text. - backbone: tokenizer to use, can use any HuggingFace tokenizer. + backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. filetype: csv or json. max_source_length: Maximum length of the source text. Any text longer will be truncated. max_target_length: Maximum length of the target text. Any text longer will be truncated. padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: the batchsize to use for parallel loading. Defaults to 8. + batch_size: The batchsize to use for parallel loading. Defaults to 8. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. + Defaults to None which equals the number of available CPU threads, + Returns: Seq2SeqData: The constructed data module. diff --git a/flash/utils/__init__.py b/flash/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/utils/imports.py b/flash/utils/imports.py new file mode 100644 index 0000000000..b0ddfa96a3 --- /dev/null +++ b/flash/utils/imports.py @@ -0,0 +1,5 @@ +from pytorch_lightning.utilities.imports import _module_available + +_TABNET_AVAILABLE = _module_available("pytorch_tabnet") +_KORNIA_AVAILABLE = _module_available("kornia") +_COCO_AVAILABLE = _module_available("pycocotools") diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 50b3ddd54e..35bb54e60f 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -13,369 +13,404 @@ # limitations under the License. 
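# The new ``flash/utils/imports.py`` above centralises optional-dependency checks
# via Lightning's ``_module_available``. A sketch of the guard pattern those flags
# enable elsewhere in this diff; "pycocotools" mirrors the ``_COCO_AVAILABLE`` flag.
from pytorch_lightning.utilities.imports import _module_available

if _module_available("pycocotools"):
    from pycocotools.coco import COCO  # detection-only dependency, imported lazily
else:
    COCO = None  # callers check the flag and skip or raise when the package is missing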
import os import pathlib -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import torch -from PIL import Image, UnidentifiedImageError +from PIL import Image +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torchvision import transforms as T -from torchvision.datasets import VisionDataset +from torch import nn +from torch.nn.modules import ModuleDict +from torch.utils.data import Dataset +from torch.utils.data._utils.collate import default_collate +from torchvision import transforms as torchvision_T from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset +from torchvision.transforms.functional import to_pil_image -from flash.core.classification import ClassificationDataPipeline -from flash.core.data.datamodule import DataModule -from flash.core.data.utils import _contains_any_tensor +from flash.data.auto_dataset import AutoDataset +from flash.data.data_module import DataModule +from flash.data.data_pipeline import DataPipeline +from flash.data.process import Preprocess +from flash.utils.imports import _KORNIA_AVAILABLE +if _KORNIA_AVAILABLE: + import kornia.augmentation as K + import kornia.geometry.transform as T +else: + from torchvision import transforms as T -def _pil_loader(path) -> Image: - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - with open(path, "rb") as f, Image.open(f) as img: - return img.convert("RGB") - -class FilepathDataset(torch.utils.data.Dataset): - """Dataset that takes in filepaths and labels.""" - - def __init__( - self, - filepaths: Optional[Sequence[Union[str, pathlib.Path]]], - labels: Optional[Sequence], - loader: Callable, - transform: Optional[Callable] = None, - ): - """ - Args: - filepaths: file paths to load with :attr:`loader` - labels: the labels corresponding to the :attr:`filepaths`. - Each unique value will get a class index by sorting them. - loader: the function to load an image from a given file path - transform: the transforms to apply to the loaded images - """ - self.fnames = filepaths or [] - self.labels = labels or [] - self.transform = transform - self.loader = loader - if not self.has_dict_labels and self.has_labels: - self.label_to_class_mapping = dict(map(reversed, enumerate(sorted(set(self.labels))))) - - @property - def has_dict_labels(self) -> bool: - return isinstance(self.labels, dict) - - @property - def has_labels(self) -> bool: - return self.labels - - def __len__(self) -> int: - return len(self.fnames) - - def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]: - filename = self.fnames[index] - img = self.loader(filename) - if self.transform: - img = self.transform(img) - label = None - if self.has_dict_labels: - name = os.path.splitext(filename)[0] - name = os.path.basename(name) - label = self.labels[name] - - elif self.has_labels: - label = self.labels[index] - label = self.label_to_class_mapping[label] - return img, label - - -class FlashDatasetFolder(VisionDataset): - """A generic data loader where the samples are arranged in this way: :: - - root/class_x/xxx.ext - root/class_x/xxy.ext - root/class_x/xxz.ext - - root/class_y/123.ext - root/class_y/nsdf3.ext - root/class_y/asd932_.ext - - Args: - root: Root directory path. - loader: A function to load a sample given its path. 
- extensions: A list of allowed extensions. both extensions - and is_valid_file should not be passed. - transform: A function/transform that takes in - a sample and returns a transformed version. - E.g, ``transforms.RandomCrop`` for images. - target_transform: A function/transform that takes - in the target and transforms it. - is_valid_file: A function that takes path of a file - and check if the file is a valid file (used to check of corrupt files) - both extensions and is_valid_file should not be passed. - with_targets: Whether to include targets - img_paths: List of image paths to load. Only used when ``with_targets=False`` - - Attributes: - classes (list): List of the class names sorted alphabetically. - class_to_idx (dict): Dict with items (class_name, class_index). - samples (list): List of (sample path, class_index) tuples - targets (list): The class_index value for each image in the dataset - """ - - def __init__( - self, - root: str, - loader: Callable, - extensions: Tuple[str] = IMG_EXTENSIONS, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - is_valid_file: Optional[Callable] = None, - with_targets: bool = True, - img_paths: Optional[List[str]] = None, - ): - super(FlashDatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) - self.loader = loader - self.extensions = extensions - self.with_targets = with_targets - - if with_targets: - classes, class_to_idx = self._find_classes(self.root) - samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) - - if len(samples) == 0: - msg = "Found 0 files in subfolders of: {}\n".format(self.root) - if extensions: - msg += "Supported extensions are: {}".format(",".join(extensions)) - raise RuntimeError(msg) - - self.classes = classes - self.class_to_idx = class_to_idx - self.samples = samples - self.targets = [s[1] for s in samples] - else: - if not img_paths: - raise MisconfigurationException( - "`FlashDatasetFolder(with_target=False)` but no `img_paths` were provided" - ) - self.samples = img_paths +class ImageClassificationPreprocess(Preprocess): + to_tensor = torchvision_T.ToTensor() @staticmethod - def _find_classes(folder: str): + def _find_classes(dir: str) -> Tuple: """ Finds the class folders in a dataset. - Args: - folder: Root directory path. - + dir: Root directory path. Returns: tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - Ensures: No class is a subdirectory of another. """ - classes = [d.name for d in os.scandir(folder) if d.is_dir()] + classes = [d.name for d in os.scandir(dir) if d.is_dir()] classes.sort() class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} return classes, class_to_idx - def __getitem__(self, index): - """ - Args: - index (int): Index + @staticmethod + def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]: + files = [] + if isinstance(samples, str): + samples = [samples] - Returns: - tuple: (sample, target) where target is class_index of the target class. 
- """ - if self.with_targets: - path, target = self.samples[index] - if self.target_transform: - target = self.target_transform(target) - else: - path = self.samples[index] - sample = self.loader(path) - if self.transform: - sample = self.transform(sample) - return (sample, target) if self.with_targets else sample + if isinstance(samples, (list, tuple)) and all(os.path.isdir(s) for s in samples): + files = [os.path.join(sp, f) for sp in samples for f in os.listdir(sp)] - def __len__(self) -> int: - return len(self.samples) + elif isinstance(samples, (list, tuple)) and all(os.path.isfile(s) for s in samples): + files = samples + files = list(filter(lambda p: has_file_allowed_extension(p, IMG_EXTENSIONS), files)) -_DEFAULT_TRAIN_TRANSFORMS = T.Compose([ - T.RandomResizedCrop(224), - T.RandomHorizontalFlip(), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), -]) + return files -_DEFAULT_VALID_TRANSFORMS = T.Compose([ - T.Resize(256), - T.CenterCrop(224), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), -]) + @classmethod + def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> List[str]: + if isinstance(data, list): + dataset.num_classes = len(data) + out = [] + for p, label in data: + if os.path.isdir(p): + for f in os.listdir(p): + if has_file_allowed_extension(f, IMG_EXTENSIONS): + out.append([os.path.join(p, f), label]) + elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS): + out.append([p, label]) + return out + else: + classes, class_to_idx = cls._find_classes(data) + dataset.num_classes = len(classes) + return make_dataset(data, class_to_idx, IMG_EXTENSIONS, None) -# todo: torch.nn.modules.module.ModuleAttributeError: 'Resize' object has no attribute '_forward_hooks' -# Find better fix and raised issue on torchvision. 
-_DEFAULT_VALID_TRANSFORMS.transforms[0]._forward_hooks = {} + @classmethod + def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Any: + _classes = [tmp[1] for tmp in data] + _classes = torch.stack([ + torch.tensor(int(_cls)) if not isinstance(_cls, torch.Tensor) else _cls.view(-1) for _cls in _classes + ]).unique() -class ImageClassificationDataPipeline(ClassificationDataPipeline): + dataset.num_classes = len(_classes) - def __init__( - self, - train_transform: Optional[Callable] = _DEFAULT_TRAIN_TRANSFORMS, - valid_transform: Optional[Callable] = _DEFAULT_VALID_TRANSFORMS, - use_valid_transform: bool = True, - loader: Callable = _pil_loader - ): - self._train_transform = train_transform - self._valid_transform = valid_transform - self._use_valid_transform = use_valid_transform - self._loader = loader + return data - def before_collate(self, samples: Any) -> Any: - if _contains_any_tensor(samples): - return samples + @classmethod + def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: + if isinstance(data, (str, pathlib.Path, list)): + return cls._load_data_dir(data=data, dataset=dataset) + return cls._load_data_files_labels(data=data, dataset=dataset) - if isinstance(samples, str): - samples = [samples] - if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): - outputs = [] - for sample in samples: - try: - output = self._loader(sample) - transform = self._valid_transform if self._use_valid_transform else self._train_transform - outputs.append(transform(output)) - except UnidentifiedImageError: - print(f'Skipping: could not read file {sample}') + @staticmethod + def load_sample(sample) -> Union[Image.Image, list]: + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + if isinstance(sample, torch.Tensor): + return sample + + if isinstance(sample, (tuple, list)): + path = sample[0] + sample = list(sample) + else: + path = sample + + with open(path, "rb") as f, Image.open(f) as img: + img = img.convert("RGB") - return outputs - raise MisconfigurationException("The samples should either be a tensor or a list of paths.") + if isinstance(sample, list): + sample[0] = img + return sample + + return img + + @classmethod + def predict_load_data(cls, samples: Any) -> Iterable: + if isinstance(samples, torch.Tensor): + return samples + return cls._get_predicting_files(samples) + + def _convert_tensor_to_pil(self, sample): + # some datasets provide their data as tensors. 
+ # however, it would be better to convert those data once in load_data + if isinstance(sample, torch.Tensor): + sample = to_pil_image(sample) + return sample + + def _apply_transform( + self, sample: Any, transform: Union[Callable, Dict[str, Callable]], func_name: str + ) -> torch.Tensor: + if transform is not None: + if isinstance(transform, (Dict, ModuleDict)): + if func_name not in transform: + return sample + else: + transform = transform[func_name] + sample = transform(sample) + return sample + + def collate(self, samples: Sequence) -> Any: + _samples = [] + # todo: Kornia transforms add batch dimension which need to be removed + for sample in samples: + if isinstance(sample, tuple): + sample = (sample[0].squeeze(0), ) + sample[1:] + else: + sample = sample.squeeze(0) + _samples.append(sample) + return default_collate(_samples) + + def common_pre_tensor_transform(self, sample: Any, transform) -> Any: + return self._apply_transform(sample, transform, "pre_tensor_transform") + + def train_pre_tensor_transform(self, sample: Any) -> Any: + source, target = sample + return self.common_pre_tensor_transform(source, self.train_transform), target + + def val_pre_tensor_transform(self, sample: Any) -> Any: + source, target = sample + return self.common_pre_tensor_transform(source, self.valid_transform), target + + def test_pre_tensor_transform(self, sample: Any) -> Any: + source, target = sample + return self.common_pre_tensor_transform(source, self.test_transform), target + + def predict_pre_tensor_transform(self, sample: Any) -> Any: + if isinstance(sample, torch.Tensor): + return sample + return self.common_pre_tensor_transform(sample, self.predict_transform) + + def to_tensor_transform(self, sample) -> Any: + source, target = sample + return source if isinstance(source, torch.Tensor) else self.to_tensor(source), target + + def predict_to_tensor_transform(self, sample) -> Any: + if isinstance(sample, torch.Tensor): + return sample + return self.to_tensor(sample) + + def common_post_tensor_transform(self, sample: Any, transform) -> Any: + return self._apply_transform(sample, transform, "post_tensor_transform") + + def train_post_tensor_transform(self, sample: Any) -> Any: + source, target = sample + return self.common_post_tensor_transform(source, self.train_transform), target + + def val_post_tensor_transform(self, sample: Any) -> Any: + source, target = sample + return self.common_post_tensor_transform(source, self.valid_transform), target + + def test_post_tensor_transform(self, sample: Any) -> Any: + source, target = sample + return self.common_post_tensor_transform(source, self.test_transform), target + + def predict_post_tensor_transform(self, sample: Any) -> Any: + return self.common_post_tensor_transform(sample, self.predict_transform) + + def train_per_batch_transform_on_device(self, batch: Tuple) -> Tuple: + batch, target = batch + return self._apply_transform(batch, self.train_transform, "per_batch_transform_on_device"), target class ImageClassificationData(DataModule): """Data module for image classification tasks.""" - @classmethod - def from_filepaths( - cls, - train_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None, - train_labels: Optional[Sequence] = None, - train_transform: Optional[Callable] = _DEFAULT_TRAIN_TRANSFORMS, - valid_split: Union[None, float] = None, - valid_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None, - valid_labels: Optional[Sequence] = None, - valid_transform: Optional[Callable] = 
_DEFAULT_VALID_TRANSFORMS, - test_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None, - test_labels: Optional[Sequence] = None, - loader: Callable = _pil_loader, - batch_size: int = 64, + preprocess_cls = ImageClassificationPreprocess + image_size = (196, 196) + + def __init__( + self, + train_dataset: Optional[Dataset] = None, + valid_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None, + predict_dataset: Optional[Dataset] = None, + batch_size: int = 1, num_workers: Optional[int] = None, seed: int = 1234, + train_split: Optional[Union[float, int]] = None, + valid_split: Optional[Union[float, int]] = None, + test_split: Optional[Union[float, int]] = None, **kwargs, ) -> 'ImageClassificationData': - """Creates a ImageClassificationData object from lists of image filepaths and labels + """Creates a ImageClassificationData object from lists of image filepaths and labels""" - Args: - train_filepaths: string or sequence of file paths for training dataset. Defaults to ``None``. - train_labels: sequence of labels for training dataset. Defaults to ``None``. - train_transform: transforms for training dataset. Defaults to ``None``. - valid_split: if not None, generates val split from train dataloader using this value. - valid_filepaths: string or sequence of file paths for validation dataset. Defaults to ``None``. - valid_labels: sequence of labels for validation dataset. Defaults to ``None``. - valid_transform: transforms for validation and testing dataset. Defaults to ``None``. - test_filepaths: string or sequence of file paths for test dataset. Defaults to ``None``. - test_labels: sequence of labels for test dataset. Defaults to ``None``. - loader: function to load an image file. Defaults to ``None``. - batch_size: the batchsize to use for parallel loading. Defaults to ``64``. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - seed: Used for the train/val splits when valid_split + if train_dataset is not None and train_split is not None or valid_split is not None or test_split is not None: + train_dataset, valid_dataset, test_dataset = self.train_valid_test_split( + train_dataset, train_split, valid_split, test_split, seed + ) - Returns: - ImageClassificationData: The constructed data module. 
+ super().__init__( + train_dataset=train_dataset, + valid_dataset=valid_dataset, + test_dataset=test_dataset, + predict_dataset=predict_dataset, + batch_size=batch_size, + num_workers=num_workers, + ) - Examples: - >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP + self._num_classes = None - Example when labels are in .csv file:: + if self._train_ds: + self.set_dataset_attribute(self._train_ds, 'num_classes', self.num_classes) - train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id') - valid_labels = labels_from_categorical_csv(path/to/valid.csv', 'my_id') - test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id') + if self._valid_ds: + self.set_dataset_attribute(self._valid_ds, 'num_classes', self.num_classes) - data = ImageClassificationData.from_filepaths( - batch_size=2, - train_filepaths='path/to/train', - train_labels=train_labels, - valid_filepaths='path/to/valid', - valid_labels=valid_labels, - test_filepaths='path/to/test', - test_labels=test_labels, + if self._test_ds: + self.set_dataset_attribute(self._test_ds, 'num_classes', self.num_classes) + + if self._predict_ds: + self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes) + + @staticmethod + def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[str, Union[nn.Module, Callable]]: + if transform and not isinstance(transform, Dict): + raise MisconfigurationException( + "Transform should be a dict. " + f"Here are the available keys for your transforms: {DataPipeline.PREPROCESS_FUNCS}." ) + return transform + @staticmethod + def default_train_transforms(): + image_size = ImageClassificationData.image_size + if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": + # Better approach as all transforms are applied on tensor directly + return { + "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()), + "per_batch_transform_on_device": nn.Sequential( + K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + } + else: + from torchvision import transforms as T # noqa F811 + return { + "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()), + "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + } + + @staticmethod + def default_valid_transforms(): + image_size = ImageClassificationData.image_size + if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": + # Better approach as all transforms are applied on tensor directly + return { + "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)), + "per_batch_transform_on_device": nn.Sequential( + K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + } + else: + from torchvision import transforms as T # noqa F811 + return { + "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]), + "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + } + + @property + def num_classes(self) -> int: + if self._num_classes is None: + if self._train_ds is not None: + self._num_classes = self._get_num_classes(self._train_ds) + + return self._num_classes + + def _get_num_classes(self, dataset: torch.utils.data.Dataset): + num_classes = self.get_dataset_attribute(dataset, "num_classes", None) + if num_classes is None: + num_classes = torch.tensor([dataset[idx][1] for idx in range(len(dataset))]).unique().numel() 
+ + return num_classes + + @classmethod + def instantiate_preprocess( + cls, + train_transform: Dict[str, Union[nn.Module, Callable]], + valid_transform: Dict[str, Union[nn.Module, Callable]], + test_transform: Dict[str, Union[nn.Module, Callable]], + predict_transform: Dict[str, Union[nn.Module, Callable]], + preprocess_cls: Type[Preprocess] = None + ) -> Preprocess: """ - # enable passing in a string which loads all files in that folder as a list - if isinstance(train_filepaths, str): - train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)] - if isinstance(valid_filepaths, str): - valid_filepaths = [os.path.join(valid_filepaths, x) for x in os.listdir(valid_filepaths)] - if isinstance(test_filepaths, str): - test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)] + This function is used to instantiate ImageClassificationData preprocess object. - train_ds = FilepathDataset( - filepaths=train_filepaths, - labels=train_labels, - loader=loader, - transform=train_transform, + Args: + train_transform: Train transforms for images. + valid_transform: Validation transforms for images. + test_transform: Test transforms for images. + predict_transform: Predict transforms for images. + preprocess_cls: User provided preprocess_cls. + + Example:: + + train_transform = { + "per_sample_transform": T.Compose([ + T.RandomResizedCrop(224), + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]), + "per_batch_transform_on_device": nn.Sequential(K.RandomAffine(360), K.ColorJitter(0.2, 0.3, 0.2, 0.3)) + } + + """ + train_transform, valid_transform, test_transform, predict_transform = cls._resolve_transforms( + train_transform, valid_transform, test_transform, predict_transform ) - if valid_split: - full_length = len(train_ds) - train_split = int((1.0 - valid_split) * full_length) - valid_split = full_length - train_split - train_ds, valid_ds = torch.utils.data.random_split( - train_ds, [train_split, valid_split], generator=torch.Generator().manual_seed(seed) - ) - else: - valid_ds = ( - FilepathDataset( - filepaths=valid_filepaths, - labels=valid_labels, - loader=loader, - transform=valid_transform, - ) if valid_filepaths else None - ) + preprocess_cls = preprocess_cls or cls.preprocess_cls + return preprocess_cls(train_transform, valid_transform, test_transform, predict_transform) - test_ds = ( - FilepathDataset( - filepaths=test_filepaths, - labels=test_labels, - loader=loader, - transform=valid_transform, - ) if test_filepaths else None - ) + @classmethod + def _resolve_transforms( + cls, + train_transform: Optional[Union[str, Dict]] = 'default', + valid_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', + ): - return cls( - train_ds=train_ds, - valid_ds=valid_ds, - test_ds=test_ds, - batch_size=batch_size, - num_workers=num_workers, + if not train_transform or train_transform == 'default': + train_transform = cls.default_train_transforms() + + if not valid_transform or valid_transform == 'default': + valid_transform = cls.default_valid_transforms() + + if not test_transform or test_transform == 'default': + test_transform = cls.default_valid_transforms() + + if not predict_transform or predict_transform == 'default': + predict_transform = cls.default_valid_transforms() + + return ( + cls._check_transforms(train_transform), cls._check_transforms(valid_transform), + 
cls._check_transforms(test_transform), cls._check_transforms(predict_transform) ) @classmethod def from_folders( cls, - train_folder: Optional[Union[str, pathlib.Path]], - train_transform: Optional[Callable] = _DEFAULT_TRAIN_TRANSFORMS, + train_folder: Optional[Union[str, pathlib.Path]] = None, valid_folder: Optional[Union[str, pathlib.Path]] = None, - valid_transform: Optional[Callable] = _DEFAULT_VALID_TRANSFORMS, test_folder: Optional[Union[str, pathlib.Path]] = None, - loader: Callable = _pil_loader, + predict_folder: Union[str, pathlib.Path] = None, + train_transform: Optional[Union[str, Dict]] = 'default', + valid_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', batch_size: int = 4, num_workers: Optional[int] = None, + preprocess_cls: Optional[Type[Preprocess]] = None, **kwargs, - ) -> 'ImageClassificationData': + ) -> 'DataModule': """ Creates a ImageClassificationData object from folders of images arranged in this way: :: @@ -387,12 +422,15 @@ def from_folders( train/cat/asd932.png Args: - train_folder: Path to training folder. - train_transform: Image transform to use for training set. - valid_folder: Path to validation folder. + train_folder: Path to training folder. Default: None. + valid_folder: Path to validation folder. Default: None. + test_folder: Path to test folder. Default: None. + predict_folder: Path to predict folder. Default: None. valid_transform: Image transform to use for validation and test set. - test_folder: Path to test folder. - loader: A function to load an image given its path. + train_transform: Image transform to use for training set. + valid_transform: Image transform to use for validation set. + test_transform: Image transform to use for test set. + predict_transform: Image transform to use for predict set. batch_size: Batch size for data loading. num_workers: The number of workers to use for parallelized loading. Defaults to ``None`` which equals the number of available CPU threads. 
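For context, a minimal usage sketch of the reworked ``from_folders`` API changed in this file; this is illustrative only (not part of the patch), with folder paths borrowed from the bundled hymenoptera example and transforms left at their defaults:

    from flash.vision import ImageClassificationData

    # Minimal sketch: the paths are assumptions reused from the bundled example;
    # "default" transforms resolve to default_train_transforms / default_valid_transforms.
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        predict_folder="data/hymenoptera_data/predict/",
        train_transform="default",
        valid_transform="default",
        batch_size=4,
    )
    # A dict keyed by the preprocess hook names is also accepted, e.g.
    # {"post_tensor_transform": ..., "per_batch_transform_on_device": ...}
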
@@ -404,35 +442,40 @@ def from_folders( >>> img_data = ImageClassificationData.from_folders("train/") # doctest: +SKIP """ - train_ds = FlashDatasetFolder(train_folder, transform=train_transform, loader=loader) - valid_ds = ( - FlashDatasetFolder(valid_folder, transform=valid_transform, loader=loader) if valid_folder else None + preprocess = cls.instantiate_preprocess( + train_transform, + valid_transform, + test_transform, + predict_transform, + preprocess_cls=preprocess_cls, ) - test_ds = (FlashDatasetFolder(test_folder, transform=valid_transform, loader=loader) if test_folder else None) - - datamodule = cls( - train_ds=train_ds, - valid_ds=valid_ds, - test_ds=test_ds, + return cls.from_load_data_inputs( + train_load_data_input=train_folder, + valid_load_data_input=valid_folder, + test_load_data_input=test_folder, + predict_load_data_input=predict_folder, batch_size=batch_size, num_workers=num_workers, + preprocess=preprocess, + **kwargs, ) - datamodule.num_classes = len(train_ds.classes) - datamodule.data_pipeline = ImageClassificationDataPipeline( - train_transform=train_transform, valid_transform=valid_transform, loader=loader - ) - return datamodule - @classmethod - def from_folder( + def from_filepaths( cls, - folder: Union[str, pathlib.Path], - transform: Optional[Callable] = _DEFAULT_VALID_TRANSFORMS, - loader: Callable = _pil_loader, + train_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, + train_labels: Optional[Sequence] = None, + valid_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, + valid_labels: Optional[Sequence] = None, + test_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, + test_labels: Optional[Sequence] = None, + predict_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, + train_transform: Optional[Callable] = 'default', + valid_transform: Optional[Callable] = 'default', batch_size: int = 64, num_workers: Optional[int] = None, + seed: Optional[int] = 42, **kwargs, ) -> 'ImageClassificationData': """ @@ -446,51 +489,102 @@ def from_folder( folder/cat_asd932_.png Args: - folder: Path to the data folder. - transform: Image transform to apply to the data. - loader: A function to load an image given its path. - batch_size: Batch size for data loading. + train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. + train_labels: Sequence of labels for training dataset. Defaults to ``None``. + valid_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. + valid_labels: Sequence of labels for validation dataset. Defaults to ``None``. + test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``. + test_labels: Sequence of labels for test dataset. Defaults to ``None``. + train_transform: Transforms for training dataset. Defaults to ``default``, which loads imagenet transforms. + valid_transform: Transforms for validation and testing dataset. + Defaults to ``default``, which loads imagenet transforms. + batch_size: The batchsize to use for parallel loading. Defaults to ``64``. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads. + Defaults to ``None`` which equals the number of available CPU threads. + seed: Used for the train/val splits when valid_split is not None. 
Returns: - ImageClassificationData: the constructed data module + ImageClassificationData: The constructed data module. Examples: - >>> img_data = ImageClassificationData.from_folder("folder/") # doctest: +SKIP + >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP + + Example when labels are in .csv file:: + + train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id') + valid_labels = labels_from_categorical_csv(path/to/valid.csv', 'my_id') + test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id') + + data = ImageClassificationData.from_filepaths( + batch_size=2, + train_filepaths='path/to/train', + train_labels=train_labels, + valid_filepaths='path/to/valid', + valid_labels=valid_labels, + test_filepaths='path/to/test', + test_labels=test_labels, + ) """ - if not os.path.isdir(folder): - raise MisconfigurationException("folder should be a directory") + # enable passing in a string which loads all files in that folder as a list + if isinstance(train_filepaths, str): + if os.path.isdir(train_filepaths): + train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)] + else: + train_filepaths = [train_filepaths] + if isinstance(valid_filepaths, str): + if os.path.isdir(valid_filepaths): + valid_filepaths = [os.path.join(valid_filepaths, x) for x in os.listdir(valid_filepaths)] + else: + valid_filepaths = [valid_filepaths] + if isinstance(test_filepaths, str): + if os.path.isdir(test_filepaths): + test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)] + else: + test_filepaths = [test_filepaths] + if isinstance(predict_filepaths, str): + if os.path.isdir(predict_filepaths): + predict_filepaths = [os.path.join(predict_filepaths, x) for x in os.listdir(predict_filepaths)] + else: + predict_filepaths = [predict_filepaths] + + if train_filepaths is not None and train_labels is not None: + train_dataset = cls._generate_dataset_if_possible( + list(zip(train_filepaths, train_labels)), running_stage=RunningStage.TRAINING + ) + else: + train_dataset = None - filenames = os.listdir(folder) + if valid_filepaths is not None and valid_labels is not None: + valid_dataset = cls._generate_dataset_if_possible( + list(zip(valid_filepaths, valid_labels)), running_stage=RunningStage.VALIDATING + ) + else: + valid_dataset = None - if any(not has_file_allowed_extension(f, IMG_EXTENSIONS) for f in filenames): - raise MisconfigurationException( - "No images with allowed extensions {IMG_EXTENSIONS} where found in {folder}" + if test_filepaths is not None and test_labels is not None: + test_dataset = cls._generate_dataset_if_possible( + list(zip(test_filepaths, test_labels)), running_stage=RunningStage.TESTING ) + else: + test_dataset = None - test_ds = ( - FlashDatasetFolder( - folder, - transform=transform, - loader=loader, - with_targets=False, - img_paths=[os.path.join(folder, f) for f in filenames] + if predict_filepaths is not None: + predict_dataset = cls._generate_dataset_if_possible( + predict_filepaths, running_stage=RunningStage.PREDICTING ) - ) + else: + predict_dataset = None - datamodule = cls( - test_ds=test_ds, + return cls( + train_dataset=train_dataset, + valid_dataset=valid_dataset, + test_dataset=test_dataset, + predict_dataset=predict_dataset, + train_transform=train_transform, + valid_transform=valid_transform, batch_size=batch_size, num_workers=num_workers, - ) - - datamodule.data_pipeline = ImageClassificationDataPipeline(valid_transform=transform, 
loader=loader) - return datamodule - - @staticmethod - def default_pipeline() -> ImageClassificationDataPipeline: - return ImageClassificationDataPipeline( - train_transform=_DEFAULT_TRAIN_TRANSFORMS, valid_transform=_DEFAULT_VALID_TRANSFORMS, loader=_pil_loader + seed=seed, + **kwargs ) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 1bbeeda020..f3774616c4 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -16,12 +16,11 @@ import torch from torch import nn from torch.nn import functional as F -from torch.nn.functional import softmax from torchmetrics import Accuracy from flash.core.classification import ClassificationTask from flash.vision.backbones import backbone_and_num_features -from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline +from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess class ImageClassifier(ClassificationTask): @@ -59,6 +58,10 @@ class ImageClassifier(ClassificationTask): learning_rate: Learning rate to use for training, defaults to ``1e-3``. """ + @property + def preprocess(self): + return ImageClassificationPreprocess(predict_transform=ImageClassificationData.default_valid_transforms()) + def __init__( self, num_classes: int, @@ -92,8 +95,4 @@ def __init__( def forward(self, x) -> Any: x = self.backbone(x) - return softmax(self.head(x)) - - @staticmethod - def default_pipeline() -> ImageClassificationDataPipeline: - return ImageClassificationData.default_pipeline() + return torch.softmax(self.head(x), -1) diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 52cddbad83..b3a32b3a35 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Type import torch from PIL import Image -from pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor, tensor from torch._six import container_abcs from torch.utils.data._utils.collate import default_collate from torchvision import transforms as T -from flash.core.data import TaskDataPipeline -from flash.core.data.datamodule import DataModule -from flash.core.data.utils import _contains_any_tensor -from flash.vision.classification.data import _pil_loader +from flash.data.auto_dataset import AutoDataset +from flash.data.data_module import DataModule +from flash.data.process import Preprocess +from flash.data.utils import _contains_any_tensor +from flash.utils.imports import _COCO_AVAILABLE +from flash.vision.utils import pil_loader -_COCO_AVAILABLE = _module_available("pycocotools") if _COCO_AVAILABLE: from pycocotools.coco import COCO @@ -40,6 +40,7 @@ def __init__( root: str, ann_file: str, transforms: Optional[Callable] = None, + loader: Optional[Callable] = pil_loader, ): if not _COCO_AVAILABLE: raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset") @@ -48,9 +49,10 @@ def __init__( self.transforms = transforms self.coco = COCO(ann_file) self.ids = list(sorted(self.coco.imgs.keys())) + self.loader = loader @property - def num_classes(self): + def num_classes(self) -> int: categories = self.coco.loadCats(self.coco.getCatIds()) if not categories: raise ValueError("No Categories found") @@ -85,12 +87,13 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: areas.append(obj["area"]) iscrowd.append(obj["iscrowd"]) - target = {} - target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32) - target["labels"] = torch.as_tensor(labels, dtype=torch.int64) - target["image_id"] = tensor([img_idx]) - target["area"] = torch.as_tensor(areas, dtype=torch.float32) - target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.int64) + target = dict( + boxes=torch.as_tensor(boxes, dtype=torch.float32), + labels=torch.as_tensor(labels, dtype=torch.int64), + image_id=tensor([img_idx]), + area=torch.as_tensor(areas, dtype=torch.float32), + iscrowd=torch.as_tensor(iscrowd, dtype=torch.int64) + ) if self.transforms: img = self.transforms(img) @@ -130,13 +133,23 @@ def _has_valid_annotation(annot: List): _default_transform = T.ToTensor() -class ObjectDetectionDataPipeline(TaskDataPipeline): +class ObjectDetectionPreprocess(Preprocess): + + to_tensor = T.ToTensor() + + def load_data(self, metadata: Any, dataset: AutoDataset) -> CustomCOCODataset: + # Extract folder, coco annotation file and the transform to be applied on the images + folder, ann_file, transform = metadata + ds = CustomCOCODataset(folder, ann_file, transform) + if self.training: + dataset.num_classes = ds.num_classes + ds = _coco_remove_images_without_annotations(ds) + return ds - def __init__(self, valid_transform: Optional[Callable] = _default_transform, loader: Callable = _pil_loader): - self._valid_transform = valid_transform - self._loader = loader + def predict_load_data(self, samples): + return samples - def before_collate(self, samples: Any) -> Any: + def pre_tensor_transform(self, samples: Any) -> Any: if _contains_any_tensor(samples): return samples @@ -146,11 +159,13 @@ def before_collate(self, samples: Any) -> Any: if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p 
in samples): outputs = [] for sample in samples: - output = self._loader(sample) - outputs.append(self._valid_transform(output)) + outputs.append(pil_loader(sample)) return outputs raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.") + def predict_to_tensor_transform(self, sample) -> Any: + return self.to_tensor(sample[0]) + def collate(self, samples: Any) -> Any: if not isinstance(samples, Tensor): elem = samples[0] @@ -162,6 +177,19 @@ def collate(self, samples: Any) -> Any: class ObjectDetectionData(DataModule): + preprocess_cls = ObjectDetectionPreprocess + + @classmethod + def instantiate_preprocess( + cls, + train_transform: Optional[Callable], + valid_transform: Optional[Callable], + preprocess_cls: Type[Preprocess] = None + ) -> Preprocess: + + preprocess_cls = preprocess_cls or cls.preprocess_cls + return preprocess_cls(train_transform, valid_transform) + @classmethod def from_coco( cls, @@ -176,24 +204,20 @@ def from_coco( test_transform: Optional[Callable] = _default_transform, batch_size: int = 4, num_workers: Optional[int] = None, + preprocess_cls: Type[Preprocess] = None, **kwargs ): - train_ds = CustomCOCODataset(train_folder, train_ann_file, train_transform) - num_classes = train_ds.num_classes - train_ds = _coco_remove_images_without_annotations(train_ds) - valid_ds = (CustomCOCODataset(valid_folder, valid_ann_file, valid_transform) if valid_folder else None) + preprocess = cls.instantiate_preprocess(train_transform, valid_transform, preprocess_cls=preprocess_cls) - test_ds = (CustomCOCODataset(test_folder, test_ann_file, test_transform) if test_folder else None) - - datamodule = cls( - train_ds=train_ds, - valid_ds=valid_ds, - test_ds=test_ds, + datamodule = cls.from_load_data_inputs( + train_load_data_input=(train_folder, train_ann_file, train_transform), + valid_load_data_input=(valid_folder, valid_ann_file, valid_transform) if valid_folder else None, + test_load_data_input=(test_folder, test_ann_file, test_transform) if test_folder else None, batch_size=batch_size, num_workers=num_workers, + preprocess=preprocess, + **kwargs ) - - datamodule.num_classes = num_classes - datamodule.data_pipeline = ObjectDetectionDataPipeline() + datamodule.num_classes = datamodule._train_ds.num_classes return datamodule diff --git a/flash/vision/detection/finetuning.py b/flash/vision/detection/finetuning.py index fd5f49368e..c1ca20072d 100644 --- a/flash/vision/detection/finetuning.py +++ b/flash/vision/detection/finetuning.py @@ -21,8 +21,8 @@ class ObjectDetectionFineTuning(FlashBaseFinetuning): Freezes the backbone during Detector training. 
""" - def __init__(self, train_bn: bool = True): - self.train_bn = train_bn + def __init__(self, train_bn: bool = True) -> None: + super().__init__(train_bn=train_bn) def freeze_before_training(self, pl_module: pl.LightningModule) -> None: model = pl_module.model diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index 60b2cf2665..897cd124b8 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -24,7 +24,6 @@ from flash.core import Task from flash.vision.backbones import backbone_and_num_features -from flash.vision.detection.data import ObjectDetectionDataPipeline from flash.vision.detection.finetuning import ObjectDetectionFineTuning _models = { @@ -189,9 +188,5 @@ def test_epoch_end(self, outs): logs = {"test_iou": avg_iou} return {"avg_test_iou": avg_iou, "log": logs} - @staticmethod - def default_pipeline() -> ObjectDetectionDataPipeline: - return ObjectDetectionDataPipeline() - def configure_finetune_callback(self): return [ObjectDetectionFineTuning(train_bn=True)] diff --git a/flash/vision/embedding/__init__.py b/flash/vision/embedding/__init__.py index 8d3ebf8c27..962ffeffe2 100644 --- a/flash/vision/embedding/__init__.py +++ b/flash/vision/embedding/__init__.py @@ -1 +1 @@ -from flash.vision.embedding.image_embedder_model import ImageEmbedder, ImageEmbedderDataPipeline +from flash.vision.embedding.model import ImageEmbedder diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/model.py similarity index 68% rename from flash/vision/embedding/image_embedder_model.py rename to flash/vision/embedding/model.py index 04ad142912..67cca42b4b 100644 --- a/flash/vision/embedding/image_embedder_model.py +++ b/flash/vision/embedding/model.py @@ -15,54 +15,14 @@ import torch from pytorch_lightning.utilities.distributed import rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.nn import functional as F from torchmetrics import Accuracy from flash.core import Task -from flash.core.data import TaskDataPipeline -from flash.core.data.utils import _contains_any_tensor +from flash.data.data_pipeline import DataPipeline from flash.vision.backbones import backbone_and_num_features -from flash.vision.classification.data import _DEFAULT_VALID_TRANSFORMS, _pil_loader - - -class ImageEmbedderDataPipeline(TaskDataPipeline): - """ - >>> from flash.vision.embedding import ImageEmbedderDataPipeline - >>> iedata = ImageEmbedderDataPipeline() - >>> iedata.before_collate(torch.tensor([1])) - tensor([1]) - >>> import os, numpy, PIL - >>> img = PIL.Image.fromarray(numpy.random.randint(0, 255, (150, 200, 3)), 'RGB') - >>> img.save('sample-image.png') - >>> iedata.before_collate('sample-image.png') # doctest: +ELLIPSIS - [tensor([[[...]]])] - >>> os.remove('sample-image.png') - """ - - def __init__( - self, - valid_transform: Optional[Callable] = _DEFAULT_VALID_TRANSFORMS, - loader: Callable = _pil_loader, - ): - self._valid_transform = valid_transform - self._loader = loader - - def before_collate(self, samples: Any) -> Any: - if _contains_any_tensor(samples): - return samples - - if isinstance(samples, str): - samples = [samples] - - if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): - outputs = [] - for sample in samples: - output = self._loader(sample) - outputs.append(self._valid_transform(output)) - return outputs - raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.") 
+from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess class ImageEmbedder(Task): @@ -87,6 +47,10 @@ class ImageEmbedder(Task): """ + @property + def preprocess(self): + return ImageClassificationPreprocess(predict_transform=ImageClassificationData.default_valid_transforms()) + def __init__( self, embedding_dim: Optional[int] = None, @@ -145,7 +109,3 @@ def forward(self, x) -> Any: x = self.head(x) return x - - @staticmethod - def default_pipeline() -> ImageEmbedderDataPipeline: - return ImageEmbedderDataPipeline() diff --git a/flash/vision/utils.py b/flash/vision/utils.py new file mode 100644 index 0000000000..d40467fcf7 --- /dev/null +++ b/flash/vision/utils.py @@ -0,0 +1,22 @@ +from typing import List, Tuple, Union + +from PIL import Image + + +def pil_loader(sample: Union[List, Tuple, str]) -> Union[Image.Image, list]: + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + + if isinstance(sample, (tuple, list)): + path = sample[0] + sample = list(sample) + else: + path = sample + + with open(path, "rb") as f, Image.open(f) as img: + img = img.convert("RGB") + + if isinstance(sample, list): + sample[0] = img + return sample + + return img diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index f4f2b596e7..a21090f66c 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import flash -from flash.core.data import download_data +from flash import Trainer from flash.core.finetuning import FreezeUnfreeze +from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data @@ -25,18 +26,29 @@ valid_folder="data/hymenoptera_data/val/", test_folder="data/hymenoptera_data/test/", ) - # 3. Build the model model = ImageClassifier(num_classes=datamodule.num_classes) -# 4. Create the trainer. Run twice on data -trainer = flash.Trainer(max_epochs=2) +# 4. Create the trainer. +trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) -# 6. Test the model -trainer.test() +# 3a. Predict what's on a few images! ants or bees? +predictions = model.predict([ + "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", + "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", + "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", +]) + +print(predictions) + +datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/") + +# 3b. Or generate predictions with a whole folder! +predictions = Trainer().predict(model, datamodule=datamodule) +print(predictions) -# 7. Save it! +# 4. Save it! trainer.save_checkpoint("image_classification_model.pt") diff --git a/flash_examples/finetuning/object_detection.py b/flash_examples/finetuning/object_detection.py index 96b97003f3..4d013c37ac 100644 --- a/flash_examples/finetuning/object_detection.py +++ b/flash_examples/finetuning/object_detection.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import flash -from flash.core.data import download_data +from flash.data.utils import download_data from flash.vision import ObjectDetectionData, ObjectDetector # 1. Download the data @@ -29,7 +29,7 @@ # 3. Build the model model = ObjectDetector(num_classes=datamodule.num_classes) -# 4. Create the trainer. Run twice on data +# 4. Create the trainer trainer = flash.Trainer(max_epochs=3) # 5. Finetune the model diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py index e8ac6d8fcf..d25efa697a 100644 --- a/flash_examples/finetuning/summarization.py +++ b/flash_examples/finetuning/summarization.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash -from flash import download_data +from flash import download_data, Trainer from flash.text import SummarizationData, SummarizationTask # 1. Download the data @@ -31,13 +33,10 @@ model = SummarizationTask() # 4. Create the trainer. Run once on data -trainer = flash.Trainer(max_epochs=1) +trainer = flash.Trainer(gpus=int(torch.cuda.is_available()), fast_dev_run=True) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule) -# 6. Test model -trainer.test() - -# 7. Save it! +# 6. Save it! trainer.save_checkpoint("summarization_model_xsum.pt") diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index 1e72cb22f7..9d5b8ad256 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -14,7 +14,7 @@ from torchmetrics.classification import Accuracy, Precision, Recall import flash -from flash.core.data import download_data +from flash.data.utils import download_data from flash.tabular import TabularClassifier, TabularData # 1. Download the data @@ -22,25 +22,25 @@ # 2. Load the data datamodule = TabularData.from_csv( - "./data/titanic/titanic.csv", + target_col="Survived", + train_csv="./data/titanic/titanic.csv", test_csv="./data/titanic/test.csv", - categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_input=["Fare"], - target="Survived", + categorical_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + numerical_cols=["Fare"], val_size=0.25, ) # 3. Build the model model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()]) -# 4. Create the trainer. Run 10 times on data -trainer = flash.Trainer(max_epochs=1) +# 4. Create the trainer +trainer = flash.Trainer(fast_dev_run=True) # 5. Train the model trainer.fit(model, datamodule=datamodule) # 6. Test model -trainer.test() +trainer.test(model) # 7. Save it! trainer.save_checkpoint("tabular_classification_model.pt") diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 4b5155b62d..0c02be354b 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import flash -from flash.core.data import download_data +from flash.data.utils import download_data from flash.text import TextClassificationData, TextClassifier # 1. 
Download the data @@ -25,20 +25,20 @@ test_file="data/imdb/test.csv", input="review", target="sentiment", - batch_size=512 + batch_size=16 ) # 3. Build the model model = TextClassifier(num_classes=datamodule.num_classes) -# 4. Create the trainer. Run once on data -trainer = flash.Trainer(max_epochs=1) +# 4. Create the trainer +trainer = flash.Trainer(fast_dev_run=True) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 6. Test model -trainer.test() +trainer.test(model) # 7. Save it! trainer.save_checkpoint("text_classification_model.pt") diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index d7a4c043eb..c057ec4790 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash import download_data from flash.text import TranslationData, TranslationTask @@ -25,19 +27,20 @@ test_file="data/wmt_en_ro/test.csv", input="input", target="target", + batch_size=1 ) # 3. Build the model model = TranslationTask() -# 4. Create the trainer. Run once on data -trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1) +# 4. Create the trainer +trainer = flash.Trainer(precision=32, gpus=int(torch.cuda.is_available()), fast_dev_run=True) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule) # 6. Test model -trainer.test() +trainer.test(model) # 7. Save it! trainer.save_checkpoint("translation_model_en_ro.pt") diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index 755f2bbd89..ec92fcb90e 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -19,7 +19,7 @@ from torchvision import datasets, transforms from flash import ClassificationTask -from flash.core.data import download_data +from flash.data.utils import download_data _PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) @@ -36,7 +36,7 @@ ) # 3. Load a dataset -dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=True, transform=transforms.ToTensor()) +dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=False, transform=transforms.ToTensor()) # 4. Split the data randomly train, val, test = random_split(dataset, [50000, 5000, 5000]) # type: ignore diff --git a/flash_examples/predict/classify_image.py b/flash_examples/predict/image_classification.py similarity index 90% rename from flash_examples/predict/classify_image.py rename to flash_examples/predict/image_classification.py index 82b21b588b..fda4a5c71a 100644 --- a/flash_examples/predict/classify_image.py +++ b/flash_examples/predict/image_classification.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from flash import Trainer -from flash.core.data import download_data +from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data @@ -30,6 +30,6 @@ print(predictions) # 3b. Or generate predictions with a whole folder! 
-datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/") +datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/") predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py index 4189285bd3..04bb155361 100644 --- a/flash_examples/predict/image_embedder.py +++ b/flash_examples/predict/image_embedder.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -from flash.core.data import download_data +from flash.data.utils import download_data from flash.vision import ImageEmbedder # 1. Download the data @@ -27,13 +27,13 @@ embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"]) # 4. Print embeddings shape -print(embeddings.shape) +print(embeddings[0].shape) # 5. Create a tensor random image -random_image = torch.randn(1, 3, 32, 32) +random_image = torch.randn(1, 3, 244, 244) # 6. Generate an embedding from this random image. embeddings = embedder.predict(random_image) # 7. Print embeddings shape -print(embeddings.shape) +print(embeddings[0].shape) diff --git a/flash_examples/predict/summarize.py b/flash_examples/predict/summarization.py similarity index 97% rename from flash_examples/predict/summarize.py rename to flash_examples/predict/summarization.py index 172a7e67da..6d16ebfcaf 100644 --- a/flash_examples/predict/summarize.py +++ b/flash_examples/predict/summarization.py @@ -13,7 +13,7 @@ # limitations under the License. from pytorch_lightning import Trainer -from flash.core.data import download_data +from flash.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download the data @@ -48,7 +48,7 @@ print(predictions) # 2b. Or generate summaries from a sheet file! -datamodule = SummarizationData.from_file( +datamodule = SummarizationData.from_files( predict_file="data/xsum/predict.csv", input="input", ) diff --git a/flash_examples/predict/classify_tabular.py b/flash_examples/predict/tabular_classification.py similarity index 90% rename from flash_examples/predict/classify_tabular.py rename to flash_examples/predict/tabular_classification.py index cb2772361f..71094a5e9e 100644 --- a/flash_examples/predict/classify_tabular.py +++ b/flash_examples/predict/tabular_classification.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from flash.core.data import download_data +from flash.data.utils import download_data from flash.tabular import TabularClassifier # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the model from a checkpoint -model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt") +model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") # 3. Generate predictions from a sheet file! Who would survive? 
predictions = model.predict("data/titanic/titanic.csv") diff --git a/flash_examples/predict/classify_text.py b/flash_examples/predict/text_classification.py similarity index 90% rename from flash_examples/predict/classify_text.py rename to flash_examples/predict/text_classification.py index 9b4a74d30a..00029e3fae 100644 --- a/flash_examples/predict/classify_text.py +++ b/flash_examples/predict/text_classification.py @@ -13,7 +13,7 @@ # limitations under the License. from pytorch_lightning import Trainer -from flash.core.data import download_data +from flash.data.utils import download_data from flash.text import TextClassificationData, TextClassifier # 1. Download the data @@ -36,6 +36,8 @@ datamodule = TextClassificationData.from_file( predict_file="data/imdb/predict.csv", input="review", + # use the same data pre-processing values we used to predict in 2a + preprocess_state=model.data_pipeline.preprocess_state, ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) diff --git a/flash_examples/predict/translate.py b/flash_examples/predict/translation.py similarity index 100% rename from flash_examples/predict/translate.py rename to flash_examples/predict/translation.py diff --git a/flash_notebooks/custom_task_tutorial.ipynb b/flash_notebooks/custom_task_tutorial.ipynb index 312144c2fe..520f989b0c 100644 --- a/flash_notebooks/custom_task_tutorial.ipynb +++ b/flash_notebooks/custom_task_tutorial.ipynb @@ -101,11 +101,14 @@ "metadata": {}, "outputs": [], "source": [ - "class DiabetesPipeline(flash.core.data.TaskDataPipeline):\n", - " def after_uncollate(self, samples):\n", + "class DiabetesPipeline(flash.data.process.Postprocess):\n", + " def per_sample_transform(self, samples):\n", " return [f\"disease progression: {float(s):.2f}\" for s in samples]\n", "\n", "class DiabetesData(flash.DataModule):\n", + " \n", + " postprocess_cls = DiabetesPipeline\n", + " \n", " def __init__(self, batch_size=64, num_workers=0):\n", " x, y = datasets.load_diabetes(return_X_y=True)\n", " x = torch.from_numpy(x).float()\n", @@ -121,11 +124,7 @@ " batch_size=batch_size,\n", " num_workers=num_workers\n", " )\n", - " self.num_inputs = x.shape[1]\n", - " \n", - " @staticmethod\n", - " def default_pipeline():\n", - " return DiabetesPipeline() " + " self.num_inputs = x.shape[1] " ] }, { @@ -158,7 +157,7 @@ "data = DiabetesData()\n", "model = LinearRegression(num_inputs=data.num_inputs)\n", "\n", - "trainer = flash.Trainer(max_epochs=1000)\n", + "trainer = flash.Trainer(max_epochs=10, progress_bar_refresh_rate=20)\n", "trainer.fit(model, data)" ] }, @@ -191,13 +190,53 @@ "source": [ "Because of our custom data pipeline's `after_uncollate` method, we will get a nicely formatted output like the following:\n", "```\n", - "['disease progression: 155.90',\n", - " 'disease progression: 156.59',\n", - " 'disease progression: 152.69',\n", - " 'disease progression: 149.05',\n", - " 'disease progression: 150.90']\n", + "[['disease progression: 14.84'],\n", + " ['disease progression: 14.86'],\n", + " ['disease progression: 14.78'],\n", + " ['disease progression: 14.73'],\n", + " ['disease progression: 14.71']]\n", "```" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Help us build Flash by adding support for new data-types and new tasks.\n", + "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", + "If you are interested, please open a PR with your contributions !!! \n", + "\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". 
\n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] } ], "metadata": { @@ -216,7 +255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/flash_notebooks/generic_task.ipynb b/flash_notebooks/generic_task.ipynb index 51a272ae1f..4e7b25e465 100644 --- a/flash_notebooks/generic_task.ipynb +++ b/flash_notebooks/generic_task.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "outstanding-knight", + "id": "determined-vinyl", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "british-auckland", + "id": "fabulous-alfred", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by creating a ClassificationTask with a custom Convolutional Model and train it on [MNIST Dataset](http://yann.lecun.com/exdb/mnist/)\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "chicken-bradford", + "id": "pleased-produce", "metadata": {}, "source": [ "# Training" @@ -34,8 +34,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "impaired-trick", + "execution_count": 8, + "id": "straight-vision", "metadata": {}, "outputs": [], "source": [ @@ -45,8 +45,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "technological-certification", + "execution_count": 9, + "id": "foreign-design", "metadata": {}, "outputs": [], "source": [ @@ -58,9 +58,24 @@ "from flash import ClassificationTask" ] }, + { + "cell_type": "code", + "execution_count": 10, + "id": "mathematical-barbados", + "metadata": {}, + "outputs": [], + "source": [ + "from six.moves import urllib\n", + "\n", + "# TorchVision hotfix https://github.com/pytorch/vision/issues/1938\n", + "opener = urllib.request.build_opener()\n", + "opener.addheaders = [('User-agent', 'Mozilla/5.0')]\n", + "urllib.request.install_opener(opener)\n" + ] + }, { "cell_type": "markdown", - "id": "several-board", + "id": "alone-scenario", "metadata": {}, "source": [ "### 1. Load a basic backbone" @@ -68,8 +83,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "upper-quest", + "execution_count": 11, + "id": "selective-request", "metadata": {}, "outputs": [], "source": [ @@ -83,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "faced-captain", + "id": "innocent-african", "metadata": {}, "source": [ "### 2. 
Load a dataset" @@ -91,17 +106,48 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "welcome-hammer", + "execution_count": 12, + "id": "stunning-anime", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "ename": "HTTPError", + "evalue": "HTTP Error 503: Service Unavailable", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMNIST\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'./data'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mToTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36mdownload\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresources\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrpartition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 146\u001b[0;31m \u001b[0mdownload_and_extract_archive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_folder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmd5\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# process and save as torch files\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_and_extract_archive\u001b[0;34m(url, download_root, extract_root, filename, md5, remove_finished)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0mdownload_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0marchive\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0;31m# check integrity of downloaded file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_integrity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Downloading '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0murl\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' to '\u001b[0m \u001b[0;34m+\u001b[0m 
\u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mURLError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIOError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'https'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/GitHub/lightning-flash/.venv/lib/python3.8/site-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36m_urlretrieve\u001b[0;34m(url, filename, chunk_size)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunk_size\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1024\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"wb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0murlopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"User-Agent\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUSER_AGENT\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36murlopen\u001b[0;34m(url, data, timeout, cafile, capath, cadefault, context)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0mopener\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_opener\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mopener\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minstall_opener\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mopen\u001b[0;34m(self, fullurl, data, timeout)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprocessor\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_response\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprotocol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mmeth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprocessor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeth_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 531\u001b[0;31m \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeth\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 532\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 533\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mhttp_response\u001b[0;34m(self, request, response)\u001b[0m\n\u001b[1;32m 638\u001b[0m \u001b[0;31m# request was successfully received, understood, and accepted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mcode\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m300\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 640\u001b[0;31m response = self.parent.error(\n\u001b[0m\u001b[1;32m 641\u001b[0m 'http', request, response, code, msg, hdrs)\n\u001b[1;32m 642\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, proto, *args)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhttp_err\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'default'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'http_error_default'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0morig_args\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 569\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_chain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 570\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;31m# XXX probably also want an abstract factory that knows when it makes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36m_call_chain\u001b[0;34m(self, chain, kind, meth_name, *args)\u001b[0m\n\u001b[1;32m 500\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhandler\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhandlers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeth_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 502\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 504\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.8.5/lib/python3.8/urllib/request.py\u001b[0m in \u001b[0;36mhttp_error_default\u001b[0;34m(self, req, fp, code, msg, hdrs)\u001b[0m\n\u001b[1;32m 647\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mHTTPDefaultErrorHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseHandler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mhttp_error_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhdrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 649\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mHTTPError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_url\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhdrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 650\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mHTTPRedirectHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBaseHandler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mHTTPError\u001b[0m: HTTP Error 503: Service Unavailable" + ] + } + ], "source": [ "dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())" ] }, { "cell_type": "markdown", - "id": "banned-gardening", + "id": "formal-teach", "metadata": {}, "source": [ "### 3. Split the data randomly" @@ -110,7 +156,7 @@ { "cell_type": "code", "execution_count": null, - "id": "southwest-muscle", + "id": "hispanic-independence", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "formal-carnival", + "id": "broken-flour", "metadata": {}, "source": [ "### 4. Create the model" @@ -128,7 +174,7 @@ { "cell_type": "code", "execution_count": null, - "id": "essential-community", + "id": "literary-destruction", "metadata": {}, "outputs": [], "source": [ @@ -137,7 +183,7 @@ }, { "cell_type": "markdown", - "id": "controlling-combination", + "id": "ancient-seller", "metadata": {}, "source": [ "### 5. Create the trainer" @@ -146,7 +192,7 @@ { "cell_type": "code", "execution_count": null, - "id": "altered-wealth", + "id": "public-berlin", "metadata": {}, "outputs": [], "source": [ @@ -159,7 +205,7 @@ }, { "cell_type": "markdown", - "id": "worldwide-fashion", + "id": "advisory-soccer", "metadata": {}, "source": [ "### 6. Train the model" @@ -168,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "according-ebony", + "id": "neural-genre", "metadata": {}, "outputs": [], "source": [ @@ -177,7 +223,7 @@ }, { "cell_type": "markdown", - "id": "spread-chambers", + "id": "greater-geneva", "metadata": {}, "source": [ "### 7. 
Test the model" @@ -186,7 +232,7 @@ { "cell_type": "code", "execution_count": null, - "id": "molecular-retention", + "id": "ideal-johnson", "metadata": {}, "outputs": [], "source": [ @@ -195,7 +241,7 @@ }, { "cell_type": "markdown", - "id": "charitable-night", + "id": "classified-cholesterol", "metadata": {}, "source": [ "# Predicting" @@ -204,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "continued-daisy", + "id": "academic-preference", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +259,7 @@ }, { "cell_type": "markdown", - "id": "nominated-found", + "id": "traditional-faculty", "metadata": {}, "source": [ "\n", @@ -269,7 +315,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb index 16cdd8e007..4bf6ba2aae 100644 --- a/flash_notebooks/image_classification.ipynb +++ b/flash_notebooks/image_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "transsexual-sense", + "id": "brutal-journalist", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "matched-chassis", + "id": "structural-literacy", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetuning/predictin with an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", @@ -43,17 +43,17 @@ { "cell_type": "code", "execution_count": null, - "id": "aboriginal-hacker", + "id": "explicit-spray", "metadata": {}, "outputs": [], "source": [ "%%capture\n", - "! pip install lightning-flash" + "! pip install git+https://github.com/PyTorchLightning/pytorch-flash.git" ] }, { "cell_type": "markdown", - "id": "preceding-sister", + "id": "boring-spanking", "metadata": {}, "source": [ "### The notebook runtime has to be re-started once Flash is installed." @@ -62,7 +62,7 @@ { "cell_type": "code", "execution_count": null, - "id": "grand-crossing", + "id": "boring-failing", "metadata": {}, "outputs": [], "source": [ @@ -75,18 +75,18 @@ { "cell_type": "code", "execution_count": null, - "id": "detailed-bikini", + "id": "domestic-correspondence", "metadata": {}, "outputs": [], "source": [ "import flash\n", - "from flash.core.data import download_data\n", + "from flash.data.utils import download_data\n", "from flash.vision import ImageClassificationData, ImageClassifier" ] }, { "cell_type": "markdown", - "id": "becoming-launch", + "id": "failing-violence", "metadata": {}, "source": [ "## 1. Download data\n", @@ -96,7 +96,7 @@ { "cell_type": "code", "execution_count": null, - "id": "missing-richmond", + "id": "registered-approval", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +105,7 @@ }, { "cell_type": "markdown", - "id": "necessary-fleet", + "id": "patent-syndication", "metadata": {}, "source": [ "

2. Load the data

\n", @@ -128,7 +128,7 @@ { "cell_type": "code", "execution_count": null, - "id": "japanese-think", + "id": "wicked-slope", "metadata": {}, "outputs": [], "source": [ @@ -141,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "intermediate-virus", + "id": "ancient-portal", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -153,7 +153,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aquatic-modification", + "id": "lucky-hopkins", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ }, { "cell_type": "markdown", - "id": "associate-poster", + "id": "finite-fleece", "metadata": {}, "source": [ "### 4. Create the trainer. Run once on data\n", @@ -179,7 +179,7 @@ { "cell_type": "code", "execution_count": null, - "id": "least-python", + "id": "isolated-aurora", "metadata": {}, "outputs": [], "source": [ @@ -188,7 +188,7 @@ }, { "cell_type": "markdown", - "id": "ethical-router", + "id": "available-making", "metadata": {}, "source": [ "### 5. Finetune the model" @@ -197,7 +197,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aggregate-radius", + "id": "multiple-washer", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "postal-regard", + "id": "entire-north", "metadata": {}, "source": [ "### 6. Test the model" @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "expired-alarm", + "id": "impressive-participant", "metadata": {}, "outputs": [], "source": [ @@ -224,7 +224,7 @@ }, { "cell_type": "markdown", - "id": "corrected-tomorrow", + "id": "crucial-allen", "metadata": {}, "source": [ "### 7. Save it!" @@ -233,7 +233,7 @@ { "cell_type": "code", "execution_count": null, - "id": "atlantic-compiler", + "id": "every-rochester", "metadata": {}, "outputs": [], "source": [ @@ -242,7 +242,7 @@ }, { "cell_type": "markdown", - "id": "improving-impact", + "id": "capital-career", "metadata": {}, "source": [ "# Predicting" @@ -250,7 +250,7 @@ }, { "cell_type": "markdown", - "id": "prostate-offset", + "id": "abandoned-cambridge", "metadata": {}, "source": [ "### 1. Load the model from a checkpoint" @@ -259,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aerial-manchester", + "id": "after-explorer", "metadata": {}, "outputs": [], "source": [ @@ -268,7 +268,7 @@ }, { "cell_type": "markdown", - "id": "bored-lover", + "id": "first-compatibility", "metadata": {}, "source": [ "### 2a. Predict what's on a few images! ants or bees?" @@ -277,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bigger-momentum", + "id": "danish-fundamentals", "metadata": {}, "outputs": [], "source": [ @@ -291,7 +291,7 @@ }, { "cell_type": "markdown", - "id": "municipal-emergency", + "id": "personal-controversy", "metadata": {}, "source": [ "### 2b. Or generate predictions with a whole folder!" 
@@ -300,18 +300,18 @@ { "cell_type": "code", "execution_count": null, - "id": "bibliographic-parts", + "id": "tribal-noise", "metadata": {}, "outputs": [], "source": [ - "datamodule = ImageClassificationData.from_folder(folder=\"data/hymenoptera_data/predict/\")\n", + "datamodule = ImageClassificationData.from_folders(predict_folder=\"data/hymenoptera_data/predict/\")\n", "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", "print(predictions)" ] }, { "cell_type": "markdown", - "id": "surprised-heath", + "id": "unlimited-burden", "metadata": {}, "source": [ "\n", @@ -367,7 +367,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index 460e996a14..3932ba7c09 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "inside-conditions", + "id": "upper-receipt", "metadata": {}, "source": [ "
\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "hispanic-typing", + "id": "herbal-commissioner", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by training a TabularClassifier on [Titanic Dataset](https://www.kaggle.com/c/titanic).\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "orange-currency", + "id": "married-failing", "metadata": {}, "source": [ "# Training" @@ -35,31 +35,31 @@ { "cell_type": "code", "execution_count": null, - "id": "textile-discovery", + "id": "innocent-bhutan", "metadata": {}, "outputs": [], "source": [ - "%%capture\n", - "! pip install lightning-flash" + "# %%capture\n", + "! pip install git+https://github.com/PyTorchLightning/pytorch-flash.git" ] }, { "cell_type": "code", "execution_count": null, - "id": "existing-clear", + "id": "expensive-chassis", "metadata": {}, "outputs": [], "source": [ "from torchmetrics.classification import Accuracy, Precision, Recall\n", "\n", "import flash\n", - "from flash.core.data import download_data\n", + "from flash.data.utils import download_data\n", "from flash.tabular import TabularClassifier, TabularData" ] }, { "cell_type": "markdown", - "id": "third-albuquerque", + "id": "virtual-supplier", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -69,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "social-maximum", + "id": "documented-humanitarian", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "informed-aaron", + "id": "german-grill", "metadata": {}, "source": [ "### 2. Load the data\n", @@ -90,23 +90,23 @@ { "cell_type": "code", "execution_count": null, - "id": "occasional-smell", + "id": "exempt-cholesterol", "metadata": {}, "outputs": [], "source": [ "datamodule = TabularData.from_csv(\n", - " \"./data/titanic/titanic.csv\",\n", + " train_csv=\"./data/titanic/titanic.csv\",\n", " test_csv=\"./data/titanic/test.csv\",\n", - " categorical_input=[\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", - " numerical_input=[\"Fare\"],\n", - " target=\"Survived\",\n", + " categorical_cols=[\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", + " numerical_cols=[\"Fare\"],\n", + " target_col=\"Survived\",\n", " val_size=0.25,\n", ")\n" ] }, { "cell_type": "markdown", - "id": "searching-hepatitis", + "id": "mineral-remove", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -117,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "checked-sleeve", + "id": "functioning-compilation", "metadata": {}, "outputs": [], "source": [ @@ -126,7 +126,7 @@ }, { "cell_type": "markdown", - "id": "labeled-intranet", + "id": "practical-highland", "metadata": {}, "source": [ "### 4. Create the trainer. Run 10 times on data" @@ -135,7 +135,7 @@ { "cell_type": "code", "execution_count": null, - "id": "tracked-centre", + "id": "pretty-layer", "metadata": {}, "outputs": [], "source": [ @@ -144,7 +144,7 @@ }, { "cell_type": "markdown", - "id": "warming-hospital", + "id": "proprietary-mitchell", "metadata": {}, "source": [ "### 5. Train the model" @@ -153,7 +153,7 @@ { "cell_type": "code", "execution_count": null, - "id": "normal-institution", + "id": "advised-contact", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +162,7 @@ }, { "cell_type": "markdown", - "id": "rural-result", + "id": "parental-norwegian", "metadata": {}, "source": [ "### 6. 
Test model" @@ -171,7 +171,7 @@ { "cell_type": "code", "execution_count": null, - "id": "lonely-comparison", + "id": "protective-scholar", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +180,7 @@ }, { "cell_type": "markdown", - "id": "parental-latvia", + "id": "operating-incident", "metadata": {}, "source": [ "### 7. Save it!" @@ -189,7 +189,7 @@ { "cell_type": "code", "execution_count": null, - "id": "educational-carter", + "id": "following-journalist", "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ }, { "cell_type": "markdown", - "id": "architectural-milton", + "id": "pointed-hunter", "metadata": {}, "source": [ "# Predicting" @@ -206,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "contrary-funds", + "id": "homeless-warrior", "metadata": {}, "source": [ "### 8. Load the model from a checkpoint\n", @@ -217,17 +217,17 @@ { "cell_type": "code", "execution_count": null, - "id": "black-joining", + "id": "personalized-panel", "metadata": {}, "outputs": [], "source": [ "model = TabularClassifier.load_from_checkpoint(\n", - " \"https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt\")" + " \"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt\")" ] }, { "cell_type": "markdown", - "id": "cloudy-monitoring", + "id": "soviet-theta", "metadata": {}, "source": [ "### 9. Generate predictions from a sheet file! Who would survive?\n", @@ -238,7 +238,7 @@ { "cell_type": "code", "execution_count": null, - "id": "alone-mumbai", + "id": "minor-siemens", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +248,7 @@ { "cell_type": "code", "execution_count": null, - "id": "stainless-guitar", + "id": "martial-hundred", "metadata": {}, "outputs": [], "source": [ @@ -257,7 +257,7 @@ }, { "cell_type": "markdown", - "id": "eastern-tenant", + "id": "portuguese-ordering", "metadata": {}, "source": [ "\n", @@ -313,7 +313,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb index 3567b48a15..9ad20120ee 100644 --- a/flash_notebooks/text_classification.ipynb +++ b/flash_notebooks/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "convinced-sunrise", + "id": "instant-bruce", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "accomplished-essay", + "id": "orange-spread", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "proved-receptor", + "id": "generic-evaluation", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "revolutionary-limit", + "id": "academic-alpha", "metadata": {}, "outputs": [], "source": [ @@ -63,18 +63,18 @@ { "cell_type": "code", "execution_count": null, - "id": "descending-consequence", + "id": "historical-asthma", "metadata": {}, "outputs": [], "source": [ "import flash\n", - "from flash.core.data import download_data\n", + "from flash.data.utils import download_data\n", "from flash.text import TextClassificationData, TextClassifier" ] }, { "cell_type": "markdown", - "id": "worse-vertex", + "id": "bronze-ghost", "metadata": {}, "source": [ "### 1. 
Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "affecting-stockholm", + "id": "applied-operation", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "tracked-brush", + "id": "instrumental-approval", "metadata": {}, "source": [ "

2. Load the data

\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "packed-albert", + "id": "flush-prince", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "combined-prior", + "id": "vital-ecuador", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "missing-century", + "id": "weighted-cosmetic", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "challenging-projector", + "id": "neural-blade", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "structural-purpose", + "id": "august-family", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "knowing-least", + "id": "figured-exhaust", "metadata": { "jupyter": { "outputs_hidden": true @@ -184,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "classified-skirt", + "id": "creative-reform", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "third-collins", + "id": "periodic-holocaust", "metadata": { "jupyter": { "outputs_hidden": true @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "chemical-murray", + "id": "stopped-clark", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "swedish-result", + "id": "turned-harris", "metadata": { "jupyter": { "outputs_hidden": true @@ -228,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "obvious-bench", + "id": "rotary-account", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "vertical-vietnam", + "id": "protective-panic", "metadata": {}, "source": [ "# Predicting" @@ -245,7 +245,7 @@ }, { "cell_type": "markdown", - "id": "beautiful-treasury", + "id": "precious-casino", "metadata": {}, "source": [ "### 1. Load the model from a checkpoint" @@ -254,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "animal-lloyd", + "id": "eligible-coordination", "metadata": {}, "outputs": [], "source": [ @@ -263,7 +263,7 @@ }, { "cell_type": "markdown", - "id": "configured-discussion", + "id": "worst-consumer", "metadata": {}, "source": [ "### 2a. Classify a few sentences! How was the movie?" @@ -272,7 +272,7 @@ { "cell_type": "code", "execution_count": null, - "id": "broke-barbados", + "id": "distinct-tragedy", "metadata": {}, "outputs": [], "source": [ @@ -288,7 +288,7 @@ }, { "cell_type": "markdown", - "id": "downtown-breathing", + "id": "limited-culture", "metadata": {}, "source": [ "### 2b. Or generate predictions from a sheet file!" 
@@ -297,7 +297,7 @@ { "cell_type": "code", "execution_count": null, - "id": "surprising-possible", + "id": "persistent-formula", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +311,7 @@ }, { "cell_type": "markdown", - "id": "printable-barrier", + "id": "other-grain", "metadata": {}, "source": [ "\n", @@ -367,7 +367,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/requirements.txt b/requirements.txt index 9856ed1c7e..53e6b9d706 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -pytorch_lightning # placeholder -git+https://github.com/PyTorchLightning/pytorch-lightning.git -torch>=1.7 # TODO: regenerate weights with lewer PT version +torch>=1.7 # TODO: regenerate weights with lower PT version +torchmetrics +torchvision>=0.8 # TODO: lower to 0.7 after PT 1.6 +pytorch_lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning.git +lightning-bolts==0.3.2 # TODO: align with proper release PyYAML>=5.1 Pillow>=7.2 -torchmetrics>=0.2.0 -torchvision>=0.8 # lower to 0.7 after PT 1.6 transformers>=4.0 pytorch-tabnet==3.1 datasets>=1.2, <1.3 @@ -14,6 +14,6 @@ numpy # comes with 3rd-party dependency tqdm # comes with 3rd-party dependency rouge-score>=0.0.4 sentencepiece>=0.1.95 -lightning-bolts==0.3.2 # todo: we shall align with proper release filelock # comes with 3rd-party dependency pycocotools>=2.0.2 ; python_version >= "3.7" +kornia>=0.5.0 diff --git a/setup.py b/setup.py index bb6a6b8fda..0cb88a53cb 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,18 @@ #!/usr/bin/env python import os +import subprocess # Always prefer setuptools over distutils import sys from setuptools import find_packages, setup +# temporary solution until next PyTorch Lightning release +try: + import pytorch_lightning + assert pytorch_lightning.__version__ == "1.3.0dev" +except (ImportError, AssertionError): + subprocess.Popen(["pip", "install", "git+https://github.com/PyTorchLightning/pytorch-lightning.git"]) + try: from flash import info, setup_tools except ImportError: diff --git a/tests/__init__.py b/tests/__init__.py index 043f7e78cd..c64310c910 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -import urllib.request +from six.moves import urllib # TorchVision hotfix https://github.com/pytorch/vision/issues/1938 opener = urllib.request.build_opener() diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 4a306894bf..12d163a50b 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -16,7 +16,6 @@ import torch from flash import DataModule -from flash.core.data import DataPipeline # ======== Mock functions ======== @@ -55,38 +54,8 @@ def test_dataloaders(): def test_cpu_count_none(): train_ds = DummyDataset() - # with patch("os.cpu_count", return_value=None), pytest.warns(UserWarning, match="Could not infer"): dm = DataModule(train_ds, num_workers=None) if platform.system() == "Darwin": assert dm.num_workers == 0 else: assert dm.num_workers > 0 - - -def test_pipeline(): - - class BoringPipeline(DataPipeline): - called = [] - - def before_collate(self, _): - self.called.append("before_collate") - - def collate(self, _): - self.called.append("collate") - - def after_collate(self, _): - self.called.append("after_collate") - - def before_uncollate(self, _): - self.called.append("before_uncollate") - - def uncollate(self, _): - self.called.append("uncollate") - - def after_uncollate(self, _): - self.called.append("after_uncollate") - - 
pipeline = BoringPipeline() - pipeline.collate_fn(None) - pipeline.uncollate_fn(torch.tensor(0)) - assert pipeline.called == [f"{b}{a}collate" for a in ("", "un") for b in ("before_", "", "after_")] diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6a291b3879..35da8590f3 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -26,7 +26,7 @@ from flash import ClassificationTask from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier -from flash.vision import ImageClassifier +from flash.vision import ImageClassificationData, ImageClassifier # ======== Mock functions ======== @@ -51,7 +51,7 @@ def __getitem__(self, index: int) -> Tensor: @pytest.mark.parametrize("metrics", [None, pl.metrics.Accuracy(), {"accuracy": pl.metrics.Accuracy()}]) def test_classificationtask_train(tmpdir: str, metrics: Any): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss, metrics=metrics) @@ -63,7 +63,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): def test_classificationtask_task_predict(): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) task = ClassificationTask(model) ds = DummyDataset() expected = list(range(10)) @@ -82,35 +82,40 @@ def test_classification_task_predict_folder_path(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() + def _rand_image(): + return Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype="uint8")) + _rand_image().save(train_dir / "1.png") _rand_image().save(train_dir / "2.png") + datamodule = ImageClassificationData.from_folders(predict_folder=train_dir) + task = ImageClassifier(num_classes=10) - predictions = task.predict(str(train_dir)) + predictions = task.predict(str(train_dir), data_pipeline=datamodule.data_pipeline) assert len(predictions) == 2 -@pytest.mark.skip("Requires DataPipeline update") # TODO def test_classification_task_trainer_predict(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) task = ClassificationTask(model) ds = PredictDummyDataset() batch_size = 3 - predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, collate_fn=task.data_pipeline.collate_fn) + predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size) trainer = pl.Trainer(default_root_dir=tmpdir) predictions = trainer.predict(task, predict_dl) - assert len(predictions) == 3 - for pred in predictions: - assert pred.shape == (3, 10) + assert len(predictions) == len(ds) // batch_size + for batch_pred in predictions: + assert len(batch_pred) == batch_size + assert all(y < 10 for y in batch_pred) def test_task_datapipeline_save(tmpdir): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) # to check later - task.data_pipeline.test = True + task.postprocess.test = True # generate a checkpoint trainer = pl.Trainer( @@ -127,15 +132,14 @@ def test_task_datapipeline_save(tmpdir): # load from file task = ClassificationTask.load_from_checkpoint(path, model=model) - assert task.data_pipeline.test + assert task.postprocess.test 
-@pytest.mark.skipif(reason="Weights are using the new API") @pytest.mark.parametrize( ["cls", "filename"], [ (ImageClassifier, "image_classification_model.pt"), - (TabularClassifier, "tabnet_classification_model.pt"), + (TabularClassifier, "tabular_classification_model.pt"), (TextClassifier, "text_classification_model.pt"), (SummarizationTask, "summarization_model_xsum.pt"), # (TranslationTask, "translation_model_en_ro.pt"), todo: reduce model size or create CI friendly file size @@ -146,7 +150,3 @@ def test_model_download(tmpdir, cls, filename): with tmpdir.as_cwd(): task = cls.load_from_checkpoint(url + filename) assert isinstance(task, cls) - - -def _rand_image(): - return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index ea08e2c806..226a69a06f 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -13,8 +13,8 @@ # limitations under the License. import os -from flash import utils -from flash.core.data import download_data +from flash.core import utils as core_utils +from flash.data.utils import download_data # ======== Mock functions ======== @@ -35,20 +35,20 @@ def b(): def test_get_callable_name(): - assert utils.get_callable_name(A()) == "a" - assert utils.get_callable_name(b) == "b" - assert utils.get_callable_name(c) == "" + assert core_utils.get_callable_name(A()) == "a" + assert core_utils.get_callable_name(b) == "b" + assert core_utils.get_callable_name(c) == "" def test_get_callable_dict(): - d = utils.get_callable_dict(A()) + d = core_utils.get_callable_dict(A()) assert type(d["a"]) == A - d = utils.get_callable_dict([A(), b]) + d = core_utils.get_callable_dict([A(), b]) assert type(d["a"]) == A assert d["b"] == b - d = utils.get_callable_dict({"one": A(), "two": b, "three": c}) + d = core_utils.get_callable_dict({"one": A(), "two": b, "three": c}) assert type(d["one"]) == A assert d["two"] == b assert d["three"] == c diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index f2ffd880ab..273b5aa870 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -173,10 +173,7 @@ def test_preprocessing_data_pipeline_no_running_stage(with_dataset): dataset = pipe._generate_auto_dataset(range(10), running_stage=None) - with pytest.raises( - RuntimeError, - match='Names for LoadSample and LoadData could not be inferred. 
Consider setting the RunningStage' - ): + with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'): for idx in range(len(dataset)): dataset[idx] diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 6610c78d81..68d10cf151 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -58,14 +58,13 @@ class CustomDataModule(DataModule): def __init__(self): super().__init__( - train_ds=DummyDataset(), - valid_ds=DummyDataset(), - test_ds=DummyDataset(), - predict_ds=DummyDataset(), + train_dataset=DummyDataset(), + valid_dataset=DummyDataset(), + test_dataset=DummyDataset(), + predict_dataset=DummyDataset(), ) -@pytest.mark.skipif(reason="Still using DataPipeline Old API") @pytest.mark.parametrize("use_preprocess", [False, True]) @pytest.mark.parametrize("use_postprocess", [False, True]) def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): @@ -85,7 +84,7 @@ class SubPostprocess(Postprocess): model = CustomModel(Postprocess()) model.data_pipeline = data_pipeline - assert isinstance(model._preprocess, Preprocess) + assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess) assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess) @@ -286,7 +285,6 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): data_pipeline.worker_preprocessor(RunningStage.PREDICTING) -@pytest.mark.skipif(reason="Still using DataPipeline Old API") def test_detach_preprocessing_from_model(tmpdir): preprocess = CustomPreprocess() @@ -334,7 +332,6 @@ def predict_per_batch_transform_on_device(self, *_, **__): pass -@pytest.mark.skipif(reason="Still using DataPipeline Old API") def test_attaching_datapipeline_to_model(tmpdir): preprocess = TestPreprocess() @@ -660,7 +657,6 @@ def val_collate(self, *_): assert not DataPipeline._is_overriden_recursive("chocolate", preprocess, Preprocess) -@pytest.mark.skipif(reason="Still using DataPipeline Old API") @mock.patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): diff --git a/tests/data/test_flash_datamodule.py b/tests/data/test_flash_datamodule.py deleted file mode 100644 index c50bd8544f..0000000000 --- a/tests/data/test_flash_datamodule.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from flash.data.data_module import DataModule - - -def test_flash_special_arguments(tmpdir): - - class CustomDataModule(DataModule): - - test = 1 - - dm = CustomDataModule() - CustomDataModule.test = 2 - assert dm.test == 2 - - class CustomDataModule2(DataModule): - - test = 1 - __flash_special_attr__ = ["test"] - - dm = CustomDataModule2() - CustomDataModule2.test = 2 - assert dm.test == 1 diff --git a/tests/data/test_serialization.py b/tests/data/test_serialization.py index 051a1ae619..1dc908e94c 100644 --- a/tests/data/test_serialization.py +++ b/tests/data/test_serialization.py @@ -38,7 +38,6 @@ def load_data(cls, data): return data -@pytest.mark.skipif(reason="Still using DataPipeline Old API") def test_serialization_data_pipeline(tmpdir): model = CustomModel() diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 9bcc4c0f06..669baee5a1 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import subprocess import sys from pathlib import Path from typing import List, Optional, Tuple +from unittest import mock import pytest @@ -49,26 +51,28 @@ def run_test(filepath): assert not code +@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "step,file", + "folder,file", [ # ("finetuning", "image_classification.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. # ("finetuning", "summarization.py"), # TODO: takes too long. ("finetuning", "tabular_classification.py"), - # ("finetuning", "text_classification.py"), + ("finetuning", "text_classification.py"), # TODO: takes too long # ("finetuning", "translation.py"), # TODO: takes too long. 
- # ("predict", "classify_image.py"), - # ("predict", "classify_tabular.py"), - # ("predict", "classify_text.py"), - # ("predict", "image_embedder.py"), - # ("predict", "summarize.py"), + ("predict", "image_classification.py"), + ("predict", "tabular_classification.py"), + ("predict", "text_classification.py"), + ("predict", "image_embedder.py"), + ("predict", "summarization.py"), # TODO: takes too long # ("predict", "translate.py"), # TODO: takes too long ] ) -def test_example(tmpdir, step, file): - run_test(str(root / "flash_examples" / step / file)) +def test_example(tmpdir, folder, file): + run_test(str(root / "flash_examples" / folder / file)) +@pytest.mark.skipif(reason="CI bug") def test_generic_example(tmpdir): run_test(str(root / "flash_examples" / "generic_task.py")) diff --git a/tests/tabular/data/test_data.py b/tests/tabular/data/test_data.py index 7ddbfeb5ea..65e04699a9 100644 --- a/tests/tabular/data/test_data.py +++ b/tests/tabular/data/test_data.py @@ -85,11 +85,11 @@ def test_tabular_data(tmpdir): train_df = TEST_DF_1.copy() valid_df = TEST_DF_2.copy() test_df = TEST_DF_2.copy() - dm = TabularData( + dm = TabularData.from_df( train_df, - categorical_input=["category"], - numerical_input=["scalar_b", "scalar_b"], - target="label", + categorical_cols=["category"], + numerical_cols=["scalar_b", "scalar_b"], + target_col="label", valid_df=valid_df, test_df=test_df, num_workers=0, @@ -110,11 +110,11 @@ def test_categorical_target(tmpdir): # change int label to string df["label"] = df["label"].astype(str) - dm = TabularData( + dm = TabularData.from_df( train_df, - categorical_input=["category"], - numerical_input=["scalar_b", "scalar_b"], - target="label", + categorical_cols=["category"], + numerical_cols=["scalar_b", "scalar_b"], + target_col="label", valid_df=valid_df, test_df=test_df, num_workers=0, @@ -133,9 +133,9 @@ def test_from_df(tmpdir): test_df = TEST_DF_2.copy() dm = TabularData.from_df( train_df, - categorical_input=["category"], - numerical_input=["scalar_b", "scalar_b"], - target="label", + categorical_cols=["category"], + numerical_cols=["scalar_b", "scalar_b"], + target_col="label", valid_df=valid_df, test_df=test_df, num_workers=0, @@ -156,10 +156,10 @@ def test_from_csv(tmpdir): TEST_DF_2.to_csv(test_csv) dm = TabularData.from_csv( - train_csv, - categorical_input=["category"], - numerical_input=["scalar_b", "scalar_b"], - target="label", + train_csv=train_csv, + categorical_cols=["category"], + numerical_cols=["scalar_b", "scalar_b"], + target_col="label", valid_csv=valid_csv, test_csv=test_csv, num_workers=0, @@ -176,5 +176,5 @@ def test_empty_inputs(): train_df = TEST_DF_1.copy() with pytest.raises(RuntimeError): TabularData.from_df( - train_df, categorical_input=None, numerical_input=None, target="label", num_workers=0, batch_size=1 + train_df, numerical_cols=None, categorical_cols=None, target_col="label", num_workers=0, batch_size=1 ) diff --git a/tests/tabular/data/test_dataset.py b/tests/tabular/data/test_dataset.py index 6ecf70b664..0039473ffa 100644 --- a/tests/tabular/data/test_dataset.py +++ b/tests/tabular/data/test_dataset.py @@ -43,7 +43,7 @@ def test_pandas(): cat_cols=["category"], num_cols=["scalar_a", "scalar_b"], target_col="label", - regression=False, + is_regression=False, ) assert len(ds) == 6 (cat, num), target = ds[0] @@ -59,7 +59,7 @@ def test_pandas_no_cat(): cat_cols=[], num_cols=["scalar_a", "scalar_b"], target_col="label", - regression=False, + is_regression=False, ) assert len(ds) == 6 (cat, num), target = ds[0] @@ -75,7 +75,7 @@ 
def test_pandas_no_num(): cat_cols=["category"], num_cols=[], target_col="label", - regression=False, + is_regression=False, ) assert len(ds) == 6 (cat, num), target = ds[0] diff --git a/tests/tabular/test_data_model_integration.py b/tests/tabular/test_data_model_integration.py index 223888cb6d..6963c561e9 100644 --- a/tests/tabular/test_data_model_integration.py +++ b/tests/tabular/test_data_model_integration.py @@ -33,9 +33,9 @@ def test_classification(tmpdir): test_df = TEST_DF_1.copy() data = TabularData.from_df( train_df, - categorical_input=["category"], - numerical_input=["scalar_a", "scalar_b"], - target="label", + categorical_cols=["category"], + numerical_cols=["scalar_a", "scalar_b"], + target_col="label", valid_df=valid_df, test_df=test_df, num_workers=0, diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index de72d94abe..499f32627c 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -17,60 +17,82 @@ import numpy as np import torch from PIL import Image -from torchvision import transforms as T from flash.data.data_utils import labels_from_categorical_csv from flash.vision import ImageClassificationData def _dummy_image_loader(_): - return torch.rand(3, 64, 64) + return torch.rand(3, 196, 196) def _rand_image(): - return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) def test_from_filepaths(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a" / "a_1.png") + _rand_image().save(tmpdir / "a" / "a_2.png") + + _rand_image().save(tmpdir / "b" / "a_1.png") + _rand_image().save(tmpdir / "b" / "a_2.png") + img_data = ImageClassificationData.from_filepaths( - train_filepaths=["a", "b"], + train_filepaths=[tmpdir / "a", tmpdir / "b"], + train_transform=None, train_labels=[0, 1], - train_transform=lambda x: x, # make sure transform works - loader=_dummy_image_loader, batch_size=1, num_workers=0, ) data = next(iter(img_data.train_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 64, 64) + assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) assert img_data.val_dataloader() is None assert img_data.test_dataloader() is None + (tmpdir / "c").mkdir() + (tmpdir / "d").mkdir() + _rand_image().save(tmpdir / "c" / "c_1.png") + _rand_image().save(tmpdir / "c" / "c_2.png") + _rand_image().save(tmpdir / "d" / "d_1.png") + _rand_image().save(tmpdir / "d" / "d_2.png") + + (tmpdir / "e").mkdir() + (tmpdir / "f").mkdir() + _rand_image().save(tmpdir / "e" / "e_1.png") + _rand_image().save(tmpdir / "e" / "e_2.png") + _rand_image().save(tmpdir / "f" / "f_1.png") + _rand_image().save(tmpdir / "f" / "f_2.png") + img_data = ImageClassificationData.from_filepaths( - train_filepaths=["a", "b"], + train_filepaths=[tmpdir / "a", tmpdir / "b"], train_labels=[0, 1], train_transform=None, - valid_filepaths=["c", "d"], + valid_filepaths=[tmpdir / "c", tmpdir / "d"], valid_labels=[0, 1], valid_transform=None, - test_filepaths=["e", "f"], + test_transform=None, + test_filepaths=[tmpdir / "e", tmpdir / "f"], test_labels=[0, 1], - loader=_dummy_image_loader, batch_size=1, num_workers=0, ) data = next(iter(img_data.val_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 64, 64) + assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) data = next(iter(img_data.test_dataloader())) imgs, labels = data - 
assert imgs.shape == (1, 3, 64, 64) + assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) @@ -123,15 +145,17 @@ def index_col_collate_fn(x): test_labels = labels_from_categorical_csv( test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn ) - data = ImageClassificationData.from_filepaths( batch_size=2, + train_transform=None, + valid_transform=None, + test_transform=None, train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'), - train_labels=train_labels, + train_labels=train_labels.values(), valid_filepaths=os.path.join(tmpdir, 'some_dataset', 'valid'), - valid_labels=valid_labels, + valid_labels=valid_labels.values(), test_filepaths=os.path.join(tmpdir, 'some_dataset', 'test'), - test_labels=test_labels, + test_labels=test_labels.values(), ) for (x, y) in data.train_dataloader(): @@ -143,16 +167,6 @@ def index_col_collate_fn(x): for (x, y) in data.test_dataloader(): assert len(x) == 2 - data = ImageClassificationData.from_filepaths( - batch_size=2, - train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'), - train_labels=train_labels, - valid_split=0.5 - ) - - for (x, y) in data.val_dataloader(): - assert len(x) == 1 - def test_from_folders(tmpdir): train_dir = Path(tmpdir / "train") @@ -171,7 +185,7 @@ def test_from_folders(tmpdir): ) data = next(iter(img_data.train_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 64, 64) + assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) assert img_data.val_dataloader() is None @@ -179,9 +193,7 @@ def test_from_folders(tmpdir): img_data = ImageClassificationData.from_folders( train_dir, - train_transform=T.ToTensor(), valid_folder=train_dir, - valid_transform=T.ToTensor(), test_folder=train_dir, batch_size=1, num_workers=0, @@ -189,10 +201,10 @@ def test_from_folders(tmpdir): data = next(iter(img_data.val_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 64, 64) + assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) data = next(iter(img_data.test_dataloader())) imgs, labels = data - assert imgs.shape == (1, 3, 64, 64) + assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/vision/classification/test_data_model_integration.py index a468f38416..1181df70ee 100644 --- a/tests/vision/classification/test_data_model_integration.py +++ b/tests/vision/classification/test_data_model_integration.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from pathlib import Path + +import numpy as np import torch +from PIL import Image from flash import Trainer from flash.vision import ImageClassificationData, ImageClassifier @@ -21,12 +25,24 @@ def _dummy_image_loader(_): return torch.rand(3, 224, 224) +def _rand_image(): + return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + + def test_classification(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a" / "a_1.png") + _rand_image().save(tmpdir / "a" / "a_2.png") + + _rand_image().save(tmpdir / "b" / "a_1.png") + _rand_image().save(tmpdir / "b" / "a_2.png") data = ImageClassificationData.from_filepaths( - train_filepaths=["a", "b"], + train_filepaths=[tmpdir / "a", tmpdir / "b"], train_labels=[0, 1], - train_transform=lambda x: x, - loader=_dummy_image_loader, + train_transform={"per_sample_per_batch_transform": lambda x: x}, num_workers=0, batch_size=2, ) diff --git a/tests/vision/detection/test_data.py b/tests/vision/detection/test_data.py index bf4ba2a170..10b242dc4d 100644 --- a/tests/vision/detection/test_data.py +++ b/tests/vision/detection/test_data.py @@ -6,10 +6,9 @@ from PIL import Image from pytorch_lightning.utilities import _module_available +from flash.utils.imports import _COCO_AVAILABLE from flash.vision.detection.data import ObjectDetectionData -_COCO_AVAILABLE = _module_available("pycocotools") - def _create_dummy_coco_json(dummy_json_path): diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/vision/detection/test_data_model_integration.py index 51fcd956b4..e2c53874c8 100644 --- a/tests/vision/detection/test_data_model_integration.py +++ b/tests/vision/detection/test_data_model_integration.py @@ -18,12 +18,11 @@ from pytorch_lightning.utilities import _module_available import flash +from flash.utils.imports import _COCO_AVAILABLE from flash.vision import ObjectDetector from flash.vision.detection import ObjectDetectionData from tests.vision.detection.test_data import _create_synth_coco_dataset -_COCO_AVAILABLE = _module_available("pycocotools") - @pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") @pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "mobilenet_v2")]) @@ -41,9 +40,8 @@ def test_detection(tmpdir, model, backbone): test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") - Image.new('RGB', (1920, 1080)).save(test_image_one) - Image.new('RGB', (1920, 1080)).save(test_image_two) + Image.new('RGB', (512, 512)).save(test_image_one) + Image.new('RGB', (512, 512)).save(test_image_two) test_images = [test_image_one, test_image_two] - model.predict(test_images)
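
The tabular hunks above consistently rename the ``TabularData`` keyword arguments (``categorical_input`` → ``categorical_cols``, ``numerical_input`` → ``numerical_cols``, ``target`` → ``target_col``) and build the datamodule through the ``from_df`` / ``from_csv`` classmethods instead of the constructor. A minimal usage sketch of the renamed call, assuming the usual ``flash.tabular`` import path; the toy DataFrame and its column names are placeholders, not part of the patch:

    import pandas as pd

    from flash.tabular import TabularData

    # Toy frame; the column names mirror the test fixtures but are otherwise arbitrary.
    train_df = pd.DataFrame({
        "category": ["a", "b", "a", "b"],
        "scalar_a": [0.1, 0.2, 0.3, 0.4],
        "scalar_b": [1.0, 2.0, 3.0, 4.0],
        "label": [0, 1, 0, 1],
    })

    datamodule = TabularData.from_df(
        train_df,
        categorical_cols=["category"],            # was: categorical_input
        numerical_cols=["scalar_a", "scalar_b"],  # was: numerical_input
        target_col="label",                       # was: target
        batch_size=2,
        num_workers=0,
    )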
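
Likewise, the vision tests drop the custom ``loader`` argument in favour of real image files on disk and pass per-stage transforms explicitly (left as ``None`` here to use the defaults). A rough sketch of the updated ``from_filepaths`` call under those assumptions; the directory and file names below are placeholders:

    from pathlib import Path

    import numpy as np
    from PIL import Image

    from flash.vision import ImageClassificationData

    # Write two placeholder images, since a loader callable can no longer be injected.
    root = Path("data/toy")
    for name in ("a", "b"):
        (root / name).mkdir(parents=True, exist_ok=True)
        Image.fromarray(
            np.random.randint(0, 255, (196, 196, 3), dtype="uint8")
        ).save(root / name / "img.png")

    datamodule = ImageClassificationData.from_filepaths(
        train_filepaths=[root / "a" / "img.png", root / "b" / "img.png"],
        train_labels=[0, 1],
        train_transform=None,  # default transforms; a dict of per-stage callables also works
        batch_size=1,
        num_workers=0,
    )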