diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 6dbbcabc0e..5c45d392e1 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -59,6 +59,10 @@ jobs: python -m coverage run --source flash -m pytest flash tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30 displayName: 'Testing' + - bash: | + bash tests/special_tests.sh + displayName: 'Testing: special' + - bash: | python -m coverage report python -m coverage xml diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 34eabb6cc3..606dd9b7d2 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -136,7 +136,15 @@ jobs: - name: Install vissl if: matrix.topic[1] == 'image_extras' run: | - pip install git+https://github.com/facebookresearch/vissl.git@master + pip install git+https://github.com/facebookresearch/ClassyVision.git + pip install git+https://github.com/facebookresearch/vissl.git + + - name: Install graph test dependencies + if: matrix.topic[0] == 'graph' + run: | + pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cpu.html + pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cpu.html - name: Install dependencies run: | @@ -166,8 +174,8 @@ jobs: uses: actions/cache@v2 with: path: data # This path is specific to Ubuntu - key: lightning-flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }} - restore-keys: lightning-flash-datasets- + key: flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }} + restore-keys: flash-datasets- - name: Tests env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 5166065e0e..11d4b0accf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,25 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [Unreleased] - YYYY-MM-DD + +### Added + +- Added support `learn2learn` training_strategy for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737)) + +- Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682)) + +- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785)) + +### Changed + +- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759)) + +### Fixed + +- Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792)) + + ## [0.5.0] - 2021-09-07 ### Added diff --git a/README.md b/README.md index 03596edcdb..3cee739f3e 100644 --- a/README.md +++ b/README.md @@ -3,18 +3,14 @@ -**Collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning** +**Your PyTorch AI Factory** ---

- Installation • + Installation • + Flash in 3 StepsDocs • - About • - Prediction • - Finetuning • - Tasks • - General TaskContributeCommunityWebsite • @@ -25,597 +21,264 @@ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/lightning-flash)](https://pypi.org/project/lightning-flash/) [![PyPI Status](https://badge.fury.io/py/lightning-flash.svg)](https://badge.fury.io/py/lightning-flash) [![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ) -[![Discourse status](https://img.shields.io/discourse/status?server=https%3A%2F%2Fforums.pytorchlightning.ai)](https://forums.pytorchlightning.ai/) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/PytorchLightning/pytorch-lightning/blob/master/LICENSE) - [![Documentation Status](https://readthedocs.org/projects/lightning-flash/badge/?version=latest)](https://lightning-flash.readthedocs.io/en/stable/?badge=stable) ![CI testing](https://github.com/PyTorchLightning/lightning-flash/workflows/CI%20testing/badge.svg?branch=master&event=push) [![codecov](https://codecov.io/gh/PyTorchLightning/lightning-flash/branch/master/graph/badge.svg?token=oLuUr9q1vt)](https://codecov.io/gh/PyTorchLightning/lightning-flash) - - --- -__Note:__ Flash is currently being tested on real-world use cases and is in active development. Please [open an issue](https://github.com/PyTorchLightning/lightning-flash/issues/new/choose) if you find anything that isn't working as expected. - ---- - -## News - -- Jul 12: Flash Task-a-thon community sprint with 25+ community members -- Jul 1: [Lightning Flash 0.4](https://devblog.pytorchlightning.ai/lightning-flash-0-4-flash-serve-fiftyone-multi-label-text-classification-and-jit-support-97428276c06f) -- Jun 22: [Ushering in the New Age of Video Understanding with PyTorch](https://medium.com/pytorch/ushering-in-the-new-age-of-video-understanding-with-pytorch-1d85078e8015) -- May 24: [Lightning Flash 0.3](https://devblog.pytorchlightning.ai/lightning-flash-0-3-new-tasks-visualization-tools-data-pipeline-and-flash-registry-api-1e236ba9530) -- May 20: [Video Understanding with PyTorch](https://towardsdatascience.com/video-understanding-made-simple-with-pytorch-video-and-lightning-flash-c7d65583c37e) -- Feb 2: [Read our launch blogpost](https://pytorch-lightning.medium.com/introducing-lightning-flash-the-fastest-way-to-get-started-with-deep-learning-202f196b3b98) - ---- +

+ Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains + + + +
-## Installation +## Getting Started -Pip / conda +From PyPI: ```bash pip install lightning-flash ``` -
- Other installations - -Pip from source - -```bash -# with git -pip install git+https://github.com/PytorchLightning/lightning-flash.git@master - -# OR from an archive -pip install https://github.com/PyTorchLightning/lightning-flash/archive/master.zip -``` - -From source using `setuptools` -``` bash -# clone flash repository locally -git clone https://github.com/PyTorchLightning/lightning-flash.git -cd lightning-flash -# install in editable mode -pip install -e . -``` - -In case you want to use the extra packages from a specific domain (image, video, text, ...) -```bash -pip install "lightning-flash[image]" -``` -See [Installation](https://lightning-flash.readthedocs.io/en/latest/installation.html) for more options. -
- ---- - -## What is Flash -Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. It is focused on: - -- Predictions -- Finetuning -- Task-based training +See [our installation guide](https://lightning-flash.readthedocs.io/en/latest/installation.html) for more options. -It is built for data scientists, machine learning practitioners, and applied researchers. +## Flash in 3 Steps +### Step 1. Load your data -## Scalability -Flash is built on top of [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) (by the Lightning team), which is a thin organizational layer on top of PyTorch. If you know PyTorch, you know PyTorch Lightning and Flash already! +All data loading in Flash is performed via a `from_*` classmethod on a `DataModule`. +Which `DataModule` to use and which `from_*` methods are available depends on the task you want to perform. +For example, for image segmentation where your data is stored in folders, you would use the [`from_folders` method of the `SemanticSegmentationData` class](https://lightning-flash.readthedocs.io/en/latest/reference/semantic_segmentation.html#from-folders): -As a result, Flash can scale up across any hardware (GPUs, TPUS) with zero changes to your code. It also has the best practices -in AI research embedded into each task so you don't have to be a deep learning PhD to leverage its power :) +```py +from flash.image import SemanticSegmentationData -### Predictions - -```python -from flash.text import TranslationTask - -# 1. Load finetuned task -model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") - -# 2. Translate a few sentences! -predictions = model.predict( - [ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", - ] +dm = SemanticSegmentationData.from_folders( + train_folder="data/CameraRGB", + train_target_folder="data/CameraSeg", + val_split=0.1, + image_size=(256, 256), + num_classes=21, ) -print(predictions) -``` - -### Serving - -`Serve` is a framework agnostic serving engine ! [Learn more](https://lightning-flash.readthedocs.io/en/latest/general/serve.html#) and [check out our examples](flash_examples/serve). - -```python -from flash.text import TextClassifier -model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") -model.serve() ``` -Credits to [@rlizzo](https://github.com/rlizzo), [@hhsecond](https://github.com/hhsecond), [@lantiga](https://github.com/lantiga), [@luiscape](https://github.com/luiscape) for building Flash Serve Engine. - -### Finetuning - -First, finetune: +### Step 2: Configure your model -```python -import flash -from flash.core.data.utils import download_data -from flash.image import ImageClassificationData, ImageClassifier +Our tasks come loaded with pre-trained backbones and (where applicable) heads. +You can view the available backbones to use with your task using [`available_backbones`](https://lightning-flash.readthedocs.io/en/latest/general/backbones.html). +Once you've chosen, create the model: -# 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") +```py +from flash.image import SemanticSegmentation -# 2. Load the data -datamodule = ImageClassificationData.from_folders( - train_folder="data/hymenoptera_data/train/", - val_folder="data/hymenoptera_data/val/", - test_folder="data/hymenoptera_data/test/", -) +print(SemanticSegmentation.available_heads()) +# ['deeplabv3', 'deeplabv3plus', 'fpn', ..., 'unetplusplus'] -# 3. Build the model -model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18") +print(SemanticSegmentation.available_backbones('fpn')) +# ['densenet121', ..., 'xception'] # + 113 models -# 4. Create the trainer. Run once on data -trainer = flash.Trainer(max_epochs=1) +print(SemanticSegmentation.available_pretrained_weights('efficientnet-b0')) +# ['imagenet', 'advprop'] -# 5. Finetune the model -trainer.finetune(model, datamodule=datamodule, strategy="freeze") - -# 6. Save it! -trainer.save_checkpoint("image_classification_model.pt") +model = SemanticSegmentation( + head="fpn", backbone='efficientnet-b0', pretrained="advprop", num_classes=dm.num_classes) ``` -Then use the finetuned model: - -```python -from flash.image import ImageClassifier +### Step 3: Finetune! -# load the finetuned model -classifier = ImageClassifier.load_from_checkpoint("image_classification_model.pt") +```py +from flash import Trainer -# predict! -predictions = classifier.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") -print(predictions) +trainer = Trainer(max_epochs=3) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") +trainer.save_checkpoint("semantic_segmentation_model.pt") ``` --- -## Tasks -Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods. +## PyTorch Recipes -### Example 1: Image embedding -Flash has an [Image Embedder task](https://lightning-flash.readthedocs.io/en/latest/reference/image_embedder.html) to encode an image into a vector of image features which can be used for anything like clustering, similarity search or classification. +### Make predictions with Flash! -
- View example +Serve in just 2 lines. -```python -from flash.core.data.utils import download_data -from flash.image import ImageEmbedder +```py +from flash.image import SemanticSegmentation -# 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") - -# 2. Create an ImageEmbedder with resnet50 trained on imagenet. -embedder = ImageEmbedder(backbone="resnet50") - -# 3. Generate an embedding from an image path. -embeddings = embedder.predict("data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg") - -# 4. Print embeddings shape -print(embeddings[0].shape) -``` - -
- -### Example 2: Text Summarization -Flash has a [Summarization task](https://lightning-flash.readthedocs.io/en/latest/reference/summarization.html) to sum up text from a larger article into a short description. - -
- View example - -```python -import flash -import torch -from flash.core.data.utils import download_data -from flash.text import SummarizationData, SummarizationTask - -# 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") - -# 2. Load the data -datamodule = SummarizationData.from_csv( - "input", - "target", - train_file="data/xsum/train.csv", - val_file="data/xsum/valid.csv", - test_file="data/xsum/test.csv", -) - -# 3. Build the model -model = SummarizationTask() - -# 4. Create the trainer. Run once on data -trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), precision=16) - -# 5. Fine-tune the model -trainer.finetune(model, datamodule=datamodule) - -# 6. Test model -trainer.test() -``` -To run the example: -```bash -python flash_examples/finetuning/summarization.py +model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt") +model.serve() ``` -
- -### Example 3: Tabular Classification - -Flash has a [Tabular Classification task](https://lightning-flash.readthedocs.io/en/latest/reference/tabular_classification.html) to tackle any tabular classification problem. - -
- View example +or make predictions from raw data directly. -To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. - -```python -from torchmetrics.classification import Accuracy, Precision, Recall -import flash -from flash.core.data.utils import download_data -from flash.tabular import TabularClassifier, TabularClassificationData - -# 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") - -# 2. Load the data -datamodule = TabularClassificationData.from_csv( - ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - "Fare", - target_fields="Survived", - train_file="./data/titanic/titanic.csv", - test_file="./data/titanic/test.csv", - val_split=0.25, -) - -# 3. Build the model -model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()]) - -# 4. Create the trainer. Run 10 times on data -trainer = flash.Trainer(max_epochs=10) - -# 5. Train the model -trainer.fit(model, datamodule=datamodule) - -# 6. Test model -trainer.test() - -# 7. Predict! -predictions = model.predict("data/titanic/titanic.csv") -print(predictions) -``` -To run the example: -```bash -python flash_examples/finetuning/tabular_data.py +```py +predictions = model.predict(["data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png"]) ``` -
- -### Example 4: Object Detection - -Flash has an [Object Detection task](https://lightning-flash.readthedocs.io/en/latest/reference/object_detection.html) to identify and locate objects in images. - -
- View example - -To illustrate, say we want to build a model on a tiny coco dataset. - -```python -import flash -from flash.core.data.utils import download_data -from flash.image import ObjectDetectionData, ObjectDetector - -# 1. Download the data -# Dataset Credit: https://www.kaggle.com/ultralytics/coco128 -download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") - -# 2. Load the Data -datamodule = ObjectDetectionData.from_coco( - train_folder="data/coco128/images/train2017/", - train_ann_file="data/coco128/annotations/instances_train2017.json", - batch_size=2, -) - -# 3. Build the model -model = ObjectDetector(num_classes=datamodule.num_classes) - -# 4. Create the trainer. Run twice on data -trainer = flash.Trainer(max_epochs=3) +or make predictions with 2 GPUs. -# 5. Finetune the model -trainer.fit(model, datamodule=datamodule) - -# 6. Save it! -trainer.save_checkpoint("object_detection_model.pt") -``` -To run the example: -```bash -python flash_examples/finetuning/object_detection.py +```py +trainer = Trainer(accelerator='ddp', gpus=2) +dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB") +predictions = trainer.predict(model, dm) ``` -
- -### Example 5: Video Classification with PyTorchVideo - -Flash has a [Video Classification task](https://lightning-flash.readthedocs.io/en/latest/reference/video_classification.html) to classify videos using [PyTorchVideo](https://pytorchvideo.org/). - -
- View example - -To illustrate, say we want to build a model to classify the kinetics data set. - -```python -import os -from torch.utils.data.sampler import RandomSampler -import flash -from flash.core.data.utils import download_data -from flash.video import VideoClassificationData, VideoClassifier - -# 1. Download a video clip dataset. Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html -download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip") - -# 2. Load the Data -datamodule = VideoClassificationData.from_folders( - train_folder=os.path.join(flash.PROJECT_ROOT, "data/kinetics/train"), - val_folder=os.path.join(flash.PROJECT_ROOT, "data/kinetics/val"), - predict_folder=os.path.join(flash.PROJECT_ROOT, "data/kinetics/predict"), - batch_size=8, - clip_sampler="uniform", - clip_duration=1, - video_sampler=RandomSampler, - decode_audio=False, - num_workers=8, +### Flash Training Strategies + +Training strategies are PyTorch SOTA Training Recipes which can be utilized with a given task. + + +Check out this [example](https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py) where the `ImageClassifier` supports 4 [Meta Learning Algorithms](https://lilianweng.github.io/lil-log/2018/11/30/meta-learning.html) from [Learn2Learn](https://github.com/learnables/learn2learn). +This is particularly useful if you use this model in production and want to make sure the model adapts quickly to its new environment with minimal labelled data. + +```py +model = ImageClassifier( + backbone="resnet18", + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, + training_strategy="prototypicalnetworks", + training_strategy_kwargs={ + "epoch_length": 10 * 16, + "meta_batch_size": 4, + "num_tasks": 200, + "test_num_tasks": 2000, + "ways": datamodule.num_classes, + "shots": 1, + "test_ways": 5, + "test_shots": 1, + "test_queries": 15, + }, ) - -# 3. Build the model -model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False) - -# 4. Create the trainer -trainer = flash.Trainer(max_epochs=3) - -# 5. Finetune the model -trainer.finetune(model, datamodule=datamodule) - -# 6. Save it! -trainer.save_checkpoint("video_classification.pt") -``` -To run the example: -```bash -python flash_examples/finetuning/video_classification.py ``` -
+In detail, the following methods are currently implemented: -### Example 6: Semantic Segmentation +* **[prototypicalnetworks](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_protonet.py)** : from Snell *et al.* 2017, [Prototypical Networks for Few-shot Learning](https://arxiv.org/abs/1703.05175) +* **[maml](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_maml.py)** : from Finn *et al.* 2017, [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) +* **[metaoptnet](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_metaoptnet.py)** : from Lee *et al.* 2019, [Meta-Learning with Differentiable Convex Optimization](https://arxiv.org/abs/1904.03758) +* **[anil](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_anil.py)** : from Raghu *et al.* 2020, [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML](https://arxiv.org/abs/1909.09157) -Flash has a [Semantic Segmentation task](https://lightning-flash.readthedocs.io/en/latest/reference/semantic_segmentation.html) for segmentation of images. +### Flash Transforms -
- View example -To illustrate, say we want to finetune a model on [this data from the Lyft Udacity Challenge](https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge). +Flash includes some simple augmentations for each task by default, however, you will often want to override these and control your own augmentation recipe. +To this end, Flash supports custom transformations backed by our powerful data pipeline. +The transform requires to be passed as a dictionary of transforms where the keys are the [hook's name](https://lightning-flash.readthedocs.io/en/latest/api/generated/flash.core.data.process.Preprocess.html?highlight=Preprocess). +This enable transforms to be applied per sample or per batch either on or off device. +It is important to note that data are being processed as a dictionary for all tasks (typically containing `input`, `target`, and `metadata`), +Therefore, you can use [`ApplyToKeys`](https://lightning-flash.readthedocs.io/en/latest/api/generated/flash.core.data.transforms.ApplyToKeys.html#flash.core.data.transforms.ApplyToKeys) utility to apply the transform to a specific key. +Complex transforms (like MixUp) can then be implemented with ease. -```python -import flash -from flash.core.data.utils import download_data -from flash.image import SemanticSegmentation, SemanticSegmentationData +The example also uses our [`merge_transforms`](https://lightning-flash.readthedocs.io/en/latest/api/generated/flash.core.data.transforms.merge_transforms.html#flash.core.data.transforms.merge_transforms) utility to merge our custom augmentations with the default transforms for images (which handle resizing and converting to a tensor). -# 1. Download the Data -download_data( - "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/" -) -# 2. Load the Data -datamodule = SemanticSegmentationData.from_folders( - train_folder="data/CameraRGB", - train_target_folder="data/CameraSeg", - batch_size=4, - val_split=0.3, - image_size=(200, 200), - num_classes=21, -) +```py +import torch +from typing import Any +import numpy as np +import albumentations +from torchvision import transforms as T +from flash.core.data.transforms import ApplyToKeys, merge_transforms +from flash.image import ImageClassificationData +from flash.image.classification.transforms import default_transforms, AlbumentationsAdapter + +def mixup(batch, alpha=1.0): + images = batch["input"] + targets = batch["target"].float().unsqueeze(1) + + lam = np.random.beta(alpha, alpha) + perm = torch.randperm(images.size(0)) + + batch["input"] = images * lam + images[perm] * (1 - lam) + batch["target"] = targets * lam + targets[perm] * (1 - lam) + return batch + +train_transform = { + # applied only on images as ApplyToKeys is used with `input` + "post_tensor_transform": ApplyToKeys( + "input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))), + + # applied to the entire dictionary as `ApplyToKeys` isn't used. + # this would be applied on GPUS ! + "per_batch_transform_on_device": mixup, + + # this would be applied on CPUS within the DataLoader workers ! + # "per_batch_transform": mixup +} +# merge the default transform for this task with new one. +train_transform = merge_transforms(default_transforms((256, 256)), train_transform) -# 3. Build the model -model = SemanticSegmentation( - backbone="torchvision/fcn_resnet50", - num_classes=datamodule.num_classes, +datamodule = ImageClassificationData.from_folders( + train_folder = "data/train", + train_transform=train_transform, ) -# 4. Create the trainer -trainer = flash.Trainer(max_epochs=3) - -# 5. Finetune the model -trainer.finetune(model, datamodule=datamodule) - -# 6. Save it! -trainer.save_checkpoint("semantic_segmentation_model.pt") -``` -To run the example: -```bash -python flash_examples/finetuning/semantic_segmentation.py ``` -
- -### Example 7: Style Transfer with pystiche - -Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [pystiche](https://pystiche.org). +## Flash Zero - PyTorch Recipes from the Command Line! -
- View example - -To illustrate, say we want to train an NST model to transfer the style from the paint demo image to the COCO data set. - -```python -import pystiche.demo -import flash -from flash.core.data.utils import download_data -from flash.image.style_transfer import StyleTransfer, StyleTransferData - -# 1. Download the Data -download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") - -# 2. Load the Data -datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4) - -# 3. Load the style image -style_image = pystiche.demo.images()["paint"].read(size=256) - -# 4. Build the model -model = StyleTransfer(style_image) - -# 5. Create the trainer -trainer = flash.Trainer(max_epochs=2) - -# 6. Train the model -trainer.fit(model, datamodule=datamodule) - -# 7. Save it! -trainer.save_checkpoint("style_transfer_model.pt") -``` -To run the example: -```bash -python flash_examples/finetuning/style_transfer.py -``` - -
- -## A general task -Flash comes prebuilt with a task to handle a huge portion of deep learning problems. - -```python -import flash -from torch import nn, optim -from torch.utils.data import DataLoader, random_split -from torchvision import transforms, datasets - -# model -model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10)) +
+ +
-# data -dataset = datasets.MNIST("./data_folder", download=True, transform=transforms.ToTensor()) -train, val = random_split(dataset, [55000, 5000]) +Flash Zero is a zero-code machine learning platform built +directly into lightning-flash +using the [`Lightning CLI`](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html). -# task -classifier = flash.Task(model, loss_fn=nn.functional.cross_entropy, optimizer=optim.Adam) +To get started and view the available tasks, run: -# train -flash.Trainer().fit(classifier, DataLoader(train), DataLoader(val)) +```py + flash --help ``` -## Infinitely customizable - -Tasks can be built in just a few minutes because Flash is built on top of PyTorch Lightning LightningModules, which -are infinitely extensible and let you train across GPUs, TPUs etc without doing any code changes. +For example, to train an image classifier for 10 epochs with a `resnet50` backbone on 2 GPUs using your own data, you can do: -```python -import torch -import torch.nn.functional as F -from torchmetrics import Accuracy -from typing import Callable, Mapping, Sequence, Type, Union -from flash.core.classification import ClassificationTask - - -class LinearClassifier(ClassificationTask): - def __init__( - self, - num_inputs, - num_classes, - loss_fn: Callable = F.cross_entropy, - optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, - metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()], - learning_rate: float = 1e-3, - ): - super().__init__( - model=None, - loss_fn=loss_fn, - optimizer=optimizer, - metrics=metrics, - learning_rate=learning_rate, - ) - self.save_hyperparameters() - - self.linear = torch.nn.Linear(num_inputs, num_classes) - - def forward(self, x): - return self.linear(x) - - -classifier = LinearClassifier(128, 10) -... +```py + flash image_classification --trainer.max_epochs 10 --trainer.gpus 2 --model.backbone resnet50 from_folders --train_folder {PATH_TO_DATA} ``` -When you reach the limits of the flexibility provided by Flash, then seamlessly transition to PyTorch Lightning which -gives you the most flexibility because it is simply organized PyTorch. - -## Visualization - -Predictions from image and video tasks can be visualized through an [integration with FiftyOne](https://lightning-flash.readthedocs.io/en/latest/integrations/fiftyone.html), allowing you to better understand and analyze how your model is performing. - -```python -from flash.core.data.utils import download_data -from flash.core.integrations.fiftyone import visualize -from flash.image import ObjectDetector -from flash.image.detection.serialization import FiftyOneDetectionLabels +--- -# 1. Download the data -# Dataset Credit: https://www.kaggle.com/ultralytics/coco128 -download_data( - "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", - "data/", -) +## News -# 2. Load the model from a checkpoint and use the FiftyOne serializer -model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/object_detection_model.pt") -model.serializer = FiftyOneDetectionLabels() +- Sept 9: [Lightning Flash 0.5](https://devblog.pytorchlightning.ai/flash-0-5-your-pytorch-ai-factory-81b172ff0d76) +- Jul 12: Flash Task-a-thon community sprint with 25+ community members +- Jul 1: [Lightning Flash 0.4](https://devblog.pytorchlightning.ai/lightning-flash-0-4-flash-serve-fiftyone-multi-label-text-classification-and-jit-support-97428276c06f) +- Jun 22: [Ushering in the New Age of Video Understanding with PyTorch](https://medium.com/pytorch/ushering-in-the-new-age-of-video-understanding-with-pytorch-1d85078e8015) +- May 24: [Lightning Flash 0.3](https://devblog.pytorchlightning.ai/lightning-flash-0-3-new-tasks-visualization-tools-data-pipeline-and-flash-registry-api-1e236ba9530) +- May 20: [Video Understanding with PyTorch](https://towardsdatascience.com/video-understanding-made-simple-with-pytorch-video-and-lightning-flash-c7d65583c37e) +- Feb 2: [Read our launch blogpost](https://pytorch-lightning.medium.com/introducing-lightning-flash-the-fastest-way-to-get-started-with-deep-learning-202f196b3b98) -# 3. Detect the object on the images -filepaths = [ - "data/coco128/images/train2017/000000000025.jpg", - "data/coco128/images/train2017/000000000520.jpg", - "data/coco128/images/train2017/000000000532.jpg", -] -predictions = model.predict(filepaths) +__Note:__ Flash is currently being tested on real-world use cases and is in active development. Please [open an issue](https://github.com/PyTorchLightning/lightning-flash/issues/new/choose) if you find anything that isn't working as expected. -# 4. Visualize predictions in FiftyOne App -session = visualize(predictions, filepaths=filepaths) -``` +--- ## Contribute! The lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks! Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ) and/or read our [CONTRIBUTING](https://github.com/PyTorchLightning/lightning-flash/blob/master/.github/CONTRIBUTING.md) guidelines to get help becoming a contributor! +--- + ## Community Flash is maintained by our [core contributors](https://lightning-flash.readthedocs.io/en/latest/governance.html). For help or questions, join our huge community on [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)! +--- + ## Citations -We’re excited to continue the strong legacy of opensource software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors. +We’re excited to continue the strong legacy of opensource software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and [fast.ai](https://arxiv.org/abs/2002.04688). When/if additional papers are written about this, we’ll be happy to cite these frameworks and the corresponding authors. -Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [open3d-ml](https://github.com/intel-isl/Open3D-ML) for pointcloud, [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts). +Flash leverages models from many different frameworks in order to cover such a wide range of domains and tasks. The full list of providers can be found in [our documentation](https://lightning-flash.readthedocs.io/en/latest/integrations/providers.html). + +--- ## License -Please observe the Apache 2.0 license that is listed in this repository. In addition -the Lightning framework is Patent Pending. +Please observe the Apache 2.0 license that is listed in this repository. diff --git a/flash/core/integrations/vissl/__init__.py b/config.yaml similarity index 100% rename from flash/core/integrations/vissl/__init__.py rename to config.yaml diff --git a/docs/source/_static/images/flash_zero.gif b/docs/source/_static/images/flash_zero.gif new file mode 100644 index 0000000000..36828a520d Binary files /dev/null and b/docs/source/_static/images/flash_zero.gif differ diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg index 2c3e330bbf..be810c5d6e 100644 --- a/docs/source/_static/images/logo.svg +++ b/docs/source/_static/images/logo.svg @@ -1 +1 @@ - + diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst deleted file mode 100644 index 0bd374deea..0000000000 --- a/docs/source/custom_task.rst +++ /dev/null @@ -1,323 +0,0 @@ -Tutorial: Creating a Custom Task -================================ - -In this tutorial we will go over the process of creating a custom :class:`~flash.core.model.Task`, -along with a custom :class:`~flash.core.data.data_module.DataModule`. - -.. note:: This tutorial is only intended to help you create a small custom task for a personal project. If you want a more detailed guide, have a look at our :ref:`guide on contributing a task to flash. ` - -The tutorial objective is to create a ``RegressionTask`` to learn to predict if someone has ``diabetes`` or not. -We will use ``scikit-learn`` `Diabetes dataset `__. -which is stored as numpy arrays. - -.. note:: - - Find the complete tutorial example at - `flash_examples/custom_task.py `_. - - -1. Imports ----------- - -We first import everything we're going to use and set the random seed using :func:`~pytorch_lightning.utilities.seed.seed_everything`. - -.. testcode:: custom_task - - from typing import Any, Callable, Dict, List, Optional, Tuple - - import numpy as np - import torch - from pytorch_lightning import seed_everything - from sklearn import datasets - from torch import nn, Tensor - - import flash - from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources - from flash.core.data.process import Preprocess - from flash.core.data.transforms import ApplyToKeys - - # set the random seeds. - seed_everything(42) - - ND = np.ndarray - - -2. The Task: Linear regression -------------------------------- - -Here we create a basic linear regression task by subclassing :class:`~flash.core.model.Task`. For the majority of tasks, -you will likely need to override the ``__init__``, ``forward``, and the ``{train,val,test,predict}_step`` methods. The -``__init__`` should be overridden to configure the model and any additional arguments to be passed to the base -:class:`~flash.core.model.Task`. ``forward`` may need to be overridden to apply the model forward pass to the inputs. -It's best practice in flash for the data to be provide as a dictionary which maps string keys to their values. The -``{train,val,test,predict}_step`` methods need to be overridden to extract the data from the input dictionary. - -.. testcode:: custom_task - - class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.2, metrics=None): - # what kind of model do we want? - model = torch.nn.Linear(num_inputs, 1) - - # what loss function do we want? - loss_fn = torch.nn.functional.mse_loss - - # what optimizer to do we want? - optimizer = torch.optim.Adam - - super().__init__( - model=model, - loss_fn=loss_fn, - optimizer=optimizer, - metrics=metrics, - learning_rate=learning_rate, - ) - - def training_step(self, batch: Any, batch_idx: int) -> Any: - return super().training_step( - (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), - batch_idx, - ) - - def validation_step(self, batch: Any, batch_idx: int) -> None: - return super().validation_step( - (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), - batch_idx, - ) - - def test_step(self, batch: Any, batch_idx: int) -> None: - return super().test_step( - (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), - batch_idx, - ) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - return super().predict_step( - batch[DefaultDataKeys.INPUT], - batch_idx, - dataloader_idx, - ) - - def forward(self, x): - # we don't actually need to override this method for this example - return self.model(x) - -.. note:: - - Lightning Flash provides registries. - Registries are Flash internal key-value database to store a mapping between a name and a function. - In simple words, they are just advanced dictionary storing a function from a key string. - They are useful to store list of backbones and make them available for a :class:`~flash.core.model.Task`. - Check out :ref:`registry` to learn more. - - -Where is the training step? -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Most models can be trained simply by passing the output of ``forward`` to the supplied ``loss_fn``, and then passing the -resulting loss to the supplied ``optimizer``. If you need a more custom configuration, you can override ``step`` (which -is called for training, validation, and testing) or override ``training_step``, ``validation_step``, and ``test_step`` -individually. These methods behave identically to PyTorch Lightning’s -`methods `__. - -Here is the pseudo code behind :class:`~flash.core.model.Task` step: - -.. code:: python - - def step(self, batch: Any, batch_idx: int) -> Any: - """ - The training/validation/test step. Override for custom behavior. - """ - x, y = batch - y_hat = self(x) - # compute the logs, loss and metrics as an output dictionary - ... - return output - - -3.a The DataSource API ----------------------- - -Now that we have defined our ``RegressionTask``, we need to load our data. We will define a custom ``NumpyDataSource`` -which extends :class:`~flash.core.data.data_source.DataSource`. The ``NumpyDataSource`` contains a ``load_data`` and -``predict_load_data`` methods which handle the loading of a sequence of dictionaries from the input numpy arrays. When -loading the train data (``if self.training:``), the ``NumpyDataSource`` sets the ``num_inputs`` attribute of the -optional ``dataset`` argument. Any attributes that are set on the optional ``dataset`` argument will also be set on the -generated ``dataset``. - -.. testcode:: custom_task - - class NumpyDataSource(DataSource[Tuple[ND, ND]]): - def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: - if self.training: - dataset.num_inputs = data[0].shape[1] - return [{DefaultDataKeys.INPUT: x, DefaultDataKeys.TARGET: y} for x, y in zip(*data)] - - def predict_load_data(self, data: ND) -> List[Dict[str, Any]]: - return [{DefaultDataKeys.INPUT: x} for x in data] - - -3.b The Preprocess API ----------------------- - -Now that we have a :class:`~flash.core.data.data_source.DataSource` implementation, we can define our -:class:`~flash.core.data.process.Preprocess`. The :class:`~flash.core.data.process.Preprocess` object provides a series of hooks -that can be overridden with custom data processing logic and to which transforms can be attached. -It allows the user much more granular control over their data processing flow. - -.. note:: - - Why introduce :class:`~flash.core.data.process.Preprocess` ? - - The :class:`~flash.core.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or - to deploy the model in production environnement compared to a traditional - `Dataset `_. - - You can override ``predict_{hook_name}`` hooks or the ``default_predict_transforms`` to handle data processing logic - specific for inference. - -The recommended way to define a custom :class:`~flash.core.data.process.Preprocess` is as follows: - -- Define an ``__init__`` which accepts transform arguments. -- Pass these arguments through to ``super().__init__`` and specify the ``data_sources`` and the ``default_data_source``. - - ``data_sources`` gives the :class:`~flash.core.data.data_source.DataSource` objects that work with your :class:`~flash.core.data.process.Preprocess` as a mapping from data source name to :class:`~flash.core.data.data_source.DataSource`. The data source name can be any string, but for our purposes we can use ``NUMPY`` from :class:`~flash.core.data.data_source.DefaultDataSources`. - - ``default_data_source`` is the name of the data source to use by default when predicting. -- Override the ``get_state_dict`` and ``load_state_dict`` methods. These methods are used to save and load your :class:`~flash.core.data.process.Preprocess` from a checkpoint. -- Override the ``{train,val,test,predict}_default_transforms`` methods to specify the default transforms to use in each stage (these will be used if the transforms passed in the ``__init__`` are ``None``). - - Transforms are given as a mapping from hook name to callable transforms. You should use :class:`~flash.core.data.transforms.ApplyToKeys` to apply each transform only to specific keys in the data dictionary. - -.. testcode:: custom_task - - class NumpyPreprocess(Preprocess): - def __init__( - self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - ): - super().__init__( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_sources={DefaultDataSources.NUMPY: NumpyDataSource()}, - default_data_source=DefaultDataSources.NUMPY, - ) - - @staticmethod - def to_float(x: Tensor): - return x.float() - - @staticmethod - def format_targets(x: Tensor): - return x.unsqueeze(0) - - @property - def to_tensor(self) -> Dict[str, Callable]: - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys( - DefaultDataKeys.INPUT, - torch.from_numpy, - self.to_float, - ), - ApplyToKeys( - DefaultDataKeys.TARGET, - torch.as_tensor, - self.to_float, - self.format_targets, - ), - ), - } - - def default_transforms(self) -> Optional[Dict[str, Callable]]: - return self.to_tensor - - def get_state_dict(self) -> Dict[str, Any]: - return self.transforms - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): - return cls(*state_dict) - - -3.c The DataModule API ----------------------- - -Now that we have a :class:`~flash.core.data.process.Preprocess` which knows about the -:class:`~flash.core.data.data_source.DataSource` objects it supports, we just need to create a -:class:`~flash.core.data.data_module.DataModule` which has a reference to the ``preprocess_cls`` we want it to use. For any -data source whose name is in :class:`~flash.core.data.data_source.DefaultDataSources`, there is a standard -``DataModule.from_*`` method that provides the expected inputs. So in this case, there is the -:meth:`~flash.core.data.data_module.DataModule.from_numpy` that will use our numpy data source. - -.. testcode:: custom_task - - class NumpyDataModule(flash.DataModule): - - preprocess_cls = NumpyPreprocess - - -You now have a new customized Flash Task! Congratulations ! - -You can fit, finetune, validate and predict directly with those objects. - -4. Fitting ----------- - -For this task, here is how to fit the ``RegressionTask`` Task on ``scikit-learn`` `Diabetes -dataset `__. - -Like any Flash Task, we can fit our model using the ``flash.Trainer`` by -supplying the task itself, and the associated data: - -.. testcode:: custom_task - - x, y = datasets.load_diabetes(return_X_y=True) - datamodule = NumpyDataModule.from_numpy(x, y) - - model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs) - - trainer = flash.Trainer( - max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count() - ) - trainer.fit(model, datamodule=datamodule) - - -.. testoutput:: custom_task - :hide: - - ... - - -5. Predicting -------------- - -With a trained model we can now perform inference. Here we will use a few examples from the test set of our data: - -.. testcode:: custom_task - - predict_data = np.array( - [ - [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], - [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], - [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], - [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], - ] - ) - - predictions = model.predict(predict_data) - print(predictions) - -We get the following output: - -.. testoutput:: custom_task - :hide: - - [tensor([...]), tensor([...]), tensor([...]), tensor([...]), tensor([...])] - -.. code-block:: - - [tensor([189.1198]), tensor([196.0839]), tensor([161.2461]), tensor([130.7591]), tensor([149.1780])] diff --git a/docs/source/general/backbones.rst b/docs/source/general/backbones.rst new file mode 100644 index 0000000000..247ddb4abb --- /dev/null +++ b/docs/source/general/backbones.rst @@ -0,0 +1,30 @@ +******************* +Backbones and Heads +******************* + +Backbones are the pre trained models that can be used with a task. +The backbones or heads that are available can be found by using the ``available_backbones`` and ``available_heads`` methods. + +To get the available backbones for a task like :class:`~flash.image.classification.model.ImageClassifier`, run: + +.. code-block:: python + + from flash.image import ImageClassifier + + # get the backbones available for ImageClassifier + backbones = ImageClassifier.available_backbones() + + # print the backbones + print(backbones) + +To get the available heads for a task like :class:`~flash.image.segmentation.model.SemanticSegmentation`, run: + +.. code-block:: python + + from flash.image import SemanticSegmentation + + # get the heads available for SemanticSegmentation + heads = SemanticSegmentation.available_heads() + + # print the heads + print(heads) diff --git a/docs/source/general/flash_zero.rst b/docs/source/general/flash_zero.rst index da3f73cbb3..282ad375c1 100644 --- a/docs/source/general/flash_zero.rst +++ b/docs/source/general/flash_zero.rst @@ -4,15 +4,104 @@ Flash Zero ********** -Flash Zero is a zero-code machine learning platform built directly into lightning-flash. -To get started and view the available tasks, run: +Flash Zero is a zero-code machine learning platform. +Here's an image classification example to illustrate with one of the dozens tasks available. + + +Flash Zero in 3 steps +_____________________ + +1. Select your task +=================== .. code-block:: bash - flash --help + flash {TASK_NAME} + +Here is the list of currently supported tasks. + +.. code-block:: bash + + audio_classification Classify audio spectrograms. + graph_classification Classify graphs. + image_classification Classify images. + instance_segmentation Segment object instances in images. + keypoint_detection Detect keypoints in images. + object_detection Detect objects in images. + pointcloud_detection Detect objects in point clouds. + pointcloud_segmentation Segment objects in point clouds. + question_answering Extractive Question Answering. + semantic_segmentation Segment objects in images. + speech_recognition Speech recognition. + style_transfer Image style transfer. + summarization Summarize text. + tabular_classification Classify tabular data. + text_classification Classify text. + translation Translate text. + video_classification Classify videos. + + +2. Pass in your own data +======================== + +.. code-block:: bash + + flash image_classification from_folders --train_folder data/hymenoptera_data/train + + +3. Modify the model and training parameters +=========================================== + +.. code-block:: bash + + flash image_classification --trainer.max_epochs 10 --model.backbone resnet50 from_folders --train_folder data/hymenoptera_data/train + +.. note:: + + The trainer and model arguments should be placed before the ``source`` subcommand. Here it is ``from_folders``. + + +Other Examples +______________ + +Image Object Detection +====================== + +To train an Object Detector on `COCO 2017 dataset `_, you could use the following command: + +.. code-block:: bash + + flash object_detection from_coco --train_folder data/coco128/images/train2017/ --train_ann_file data/coco128/annotations/instances_train2017.json --val_split .3 --batch_size 8 --num_workers 4 + + +Image Object Segmentation +========================= + +To train an Image Segmenter on `CARLA driving simulator dataset `_ + +.. code-block:: bash -Customize Trainer and Model arguments -_____________________________________ + flash semantic_segmentation from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21 + +Below is an example where the head, the backbone and its pretrained weights are customized. + +.. code-block:: bash + + flash semantic_segmentation --model.head fpn --model.backbone efficientnet-b0 --model.pretrained advprop from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21 + +Video Classification +==================== + +To train an Video Classifier on the `Kinetics dataset `_, you could use the following command: + + +.. code-block:: bash + + flash video_classification from_folders --train_folder data/kinetics/train/ --clip_duration 1 --num_workers 0 + + +CLI options +___________ Flash Zero is built on top of the `lightning CLI `_, so the trainer and @@ -29,8 +118,8 @@ To view all of the available options for a task, run: flash image_classification --help -Using Custom Data -_________________ +Using Your Own Data +___________________ Flash Zero works with your own data through subcommands. The available subcommands for each task are given at the bottom of their help pages (e.g. when running :code:`flash image-classification --help`). You can then use the required @@ -48,9 +137,54 @@ Now train with Flash Zero: flash image_classification from_folders --train_folder ./hymenoptera_data/train +Getting Help +____________ + +To find all available tasks, you can run: + +.. code-block:: bash + + flash --help + +This will output the following: + +.. code-block:: bash + + Commands: + audio_classification Classify audio spectrograms. + graph_classification Classify graphs. + image_classification Classify images. + instance_segmentation Segment object instances in images. + keypoint_detection Detect keypoints in images. + object_detection Detect objects in images. + pointcloud_detection Detect objects in point clouds. + pointcloud_segmentation Segment objects in point clouds. + question_answering Extractive Question Answering. + semantic_segmentation Segment objects in images. + speech_recognition Speech recognition. + style_transfer Image style transfer. + summarization Summarize text. + tabular_classification Classify tabular data. + text_classification Classify text. + translation Translate text. + video_classification Classify videos. + + +To get more information about a specific task, you can do the following: + +.. code-block:: bash + + flash image_classification --help + You can view the help page for each subcommand. For example, to view the options for training an image classifier from folders, you can run: .. code-block:: bash flash image_classification from_folders --help + +Finally, you can generate a `config.yaml` file from the client to ease parameters modification by running: + +.. code-block:: bash + + flash image_classification --print_config > config.yaml diff --git a/docs/source/reference/flash_to_production.rst b/docs/source/general/production.rst similarity index 65% rename from docs/source/reference/flash_to_production.rst rename to docs/source/general/production.rst index f9081c29a1..59e07b74c4 100644 --- a/docs/source/reference/flash_to_production.rst +++ b/docs/source/general/production.rst @@ -1,8 +1,11 @@ -######################## -From Flash to Production -######################## +################### +Flash in Production +################### -Flash makes it simple to deploy models in production. +Flash Serve +=========== + +Flash Serve makes model deployment simple. Server Side ^^^^^^^^^^^ @@ -20,4 +23,5 @@ Client Side :lines: 14- -Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building Flash Serve Engine. +Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building the Flash Serve Engine. +Read all about it :ref:`here `. diff --git a/docs/source/general/serve.rst b/docs/source/general/serve.rst index 4e09ff6059..5ddab0c914 100644 --- a/docs/source/general/serve.rst +++ b/docs/source/general/serve.rst @@ -1,9 +1,9 @@ +.. _serve: + ########### Flash Serve ########### -.. _serve: - Flash Serve is a library to easily serve models in production. *********** diff --git a/docs/source/index.rst b/docs/source/index.rst index 5a6ab4687e..9fcace7dfe 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,9 +18,8 @@ Lightning Flash quickstart installation - custom_task - reference/flash_to_pl - reference/flash_to_production + general/flash_zero + general/production .. toctree:: :maxdepth: 1 @@ -32,8 +31,8 @@ Lightning Flash general/jit general/data general/registry - general/flash_zero general/serve + general/backbones .. toctree:: :maxdepth: 1 @@ -90,8 +89,11 @@ Lightning Flash :caption: Integrations integrations/providers + integrations/baal integrations/fiftyone integrations/icevision + integrations/learn2learn + integrations/vissl .. toctree:: :maxdepth: 1 diff --git a/docs/source/installation.md b/docs/source/installation.md index d306090c11..2c7bd86c49 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -1,25 +1,25 @@ # Installation -Flash is tested on Python 3.6+, and PyTorch 1.6. - ## Install with pip ```bash pip install lightning-flash ``` -Optionally, you can install Flash with extra packages for each domain or all domains. +Optionally, you can install Flash with extra packages for each domain. + +For a single domain, use: `pip install 'lightning-flash[{DOMAIN}]'`. ```bash pip install 'lightning-flash[image]' pip install 'lightning-flash[tabular]' pip install 'lightning-flash[text]' -pip install 'lightning-flash[video]' - -# image + video -pip install 'lightning-flash[vision]' +... +``` -# all features -pip install 'lightning-flash[all]' +For muliple domains, use: `pip install 'lightning-flash[{DOMAIN_1, DOMAIN_2, ...}]'`. +```bash +pip install 'lightning-flash[audio,image]' +... ``` For contributors, please install Flash with packages for testing Flash and building docs. @@ -32,8 +32,23 @@ cd lightning-flash pip install -e '.[dev]' ``` +## Install with conda + +Flash is available via conda forge. Install it with: +```bash +conda install -c conda-forge lightning-flash +``` + ## Install from source +You can install Flash from source without any domain specific dependencies with: +```bash +pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git' +``` + +To install Flash with domain dependencies, use: ```bash -pip install git+https://github.com/PyTorchLightning/lightning-flash.git +pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]' ``` + +You can again install dependencies for multiple domains by separating them with commas as above. diff --git a/docs/source/integrations/baal.rst b/docs/source/integrations/baal.rst new file mode 100644 index 0000000000..4e6a61063e --- /dev/null +++ b/docs/source/integrations/baal.rst @@ -0,0 +1,32 @@ +.. _baal: + +#### +BaaL +#### + +The framework `Bayesian Active Learning (BaaL) `_ is an active learning +library developed at `ElementAI `_. + +.. raw:: html + +
+ +
+ + +Active Learning is a sub-field in AI, focusing on adding a human in the learning loop. +The most uncertain samples will be labelled by the human to accelerate the model training cycle. + +.. raw:: html + +
+ +

Credit to ElementAI / Baal Team for creating this diagram flow

+
+
+ +With its integration within Flash, the Active Learning process is simpler than ever before. + +.. literalinclude:: ../../../flash_examples/integrations/baal/image_classification_active_learning.py + :language: python + :lines: 14- diff --git a/docs/source/integrations/learn2learn.rst b/docs/source/integrations/learn2learn.rst new file mode 100644 index 0000000000..18ae188a0a --- /dev/null +++ b/docs/source/integrations/learn2learn.rst @@ -0,0 +1,81 @@ +.. _learn2learn: + +########### +Learn2Learn +########### + +`Learn2Learn `__ is a software library for meta-learning research by `Sébastien M. R. Arnold and al.` (Aug 2020) + +.. raw:: html + +
+ +
+
+ + + +What is Meta-Learning and why you should care? +---------------------------------------------- + +Humans can distinguish between new objects with little or no training data, +However, machine learning models often require thousands, millions, billions of annotated data samples +to achieve good performance while extrapolating their learned knowledge on unseen objects. + +A machine learning model which could learn or learn to learn from only few new samples (K-shot learning) would have tremendous applications +once deployed in production. +In an extreme case, a model performing 1-shot or 0-shot learning could be the source of new kind of AI applications. + +Meta-Learning is a sub-field of AI dedicated to the study of few-shot learning algorithms. +This is often characterized as teaching deep learning models to learn with only a few labeled data. +The goal is to repeatedly learn from K-shot examples during training that match the structure of the final K-shot used in production. +It is important to note that the K-shot example seen in production are very likely to be completely out-of-distribution with new objects. + + +How does Meta-Learning work? +---------------------------- + +In meta-learning, the model is trained over multiple meta tasks. +A meta task is the smallest unit of data and it represents the data available to the model once in its deployment environment. +By doing so, we can optimise the model and get higher results. + +.. raw:: html + +
+ +
+
+ +For image classification, a meta task is comprised of shot + query elements for each class. +The shots samples are used to adapt the parameters and the queries ones to update the original model weights. +The classes used in the validation and testing shouldn't be present within the training dataset, +as the goal is to optimise the model performance on out-of-distribution (OOD) data with little label data. + +When training the model with the meta-learning algorithm, +the model will average its gradients over meta_batch_size meta tasks before performing an optimizer step. +Traditionally, an meta epoch is composed of multiple meta batch. + +Use Meta-Learning with Flash +---------------------------- + +With its integration within Flash, Meta Learning has never been simpler. +Flash takes care of all the hard work: the tasks sampling, meta optimizer update, distributed training, etc... + +.. note:: + + The users requires to provide a training dataset and testing dataset with no overlapping classes. + Flash doesn't support this feature out-of-the box. + +Once done, the users are left to play the hyper-parameters associated with the meta-learning algorithm. + +Here is an example using `miniImageNet dataset `_ containing 100 classes divided into 64 training, 16 validation, and 20 test classes. + +.. literalinclude:: ../../../flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py + :language: python + :lines: 15- + + +You can read their paper `Learn2Learn: A Library for Meta-Learning Research `_. + +And don't forget to cite `Learn2Learn `__ repository in your academic publications. +Find their Biblex on their repository. diff --git a/docs/source/integrations/vissl.rst b/docs/source/integrations/vissl.rst new file mode 100644 index 0000000000..a5a64f6fdf --- /dev/null +++ b/docs/source/integrations/vissl.rst @@ -0,0 +1,33 @@ +.. _vissl: + +##### +VISSL +##### + +`VISSL `__ is a library from Facebook AI Research for state-of-the-art self-supervised learning. +We integrate VISSL models and algorithms into Flash with the :ref:`image embedder ` task. + +Using VISSL with Flash +---------------------- + +The ImageEmbedder task in Flash can be configured with different backbones, projection heads, image transforms and loss functions so that you can train your feature extractor using a SOTA SSL method. + +.. code-block:: python + + from flash.image import ImageEmbedder + + embedder = ImageEmbedder( + backbone="resnet", + training_strategy="barlow_twins", + head="simclr_head", + pretraining_transform="barlow_twins_transform", + training_strategy_kwargs={"latent_embedding_dim": 256, "dims": [2048, 2048, 256]}, + pretraining_transform_kwargs={"size_crops": [196]}, + ) + +The user can pass arguments to the training strategy, image transforms and backbones using the optional dictionary arguments the ImageEmbedder task accepts. +The training strategies club together the projection head, the loss function as well as VISSL hooks for a particular algorithm and the arguments to customize these can passed via ``training_strategy_kwargs``. +As an example, in the above code block, the ``latent_embedding_dim`` is an argument to the BarlowTwins loss function from VISSL, while the ``dims`` argument configures the projection head to output 256 dim vectors for the loss function. + +If you find VISSL integration in Flash useful for your research, please don't forget to cite us and the VISSL library. +You can find our bibtex on `Flash `__ and VISSL's bibxtex on their `github `__ page. diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 85cf5b6f53..36c3ba3475 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -52,10 +52,6 @@ Standard best practices ^^^^^^^^^^^^^^^^^^^^^^^ Flash tasks implement the standard best practices for a variety of different models and domains, to save you time digging through different implementations. Flash abstracts even more details than Lightning, allowing deep learning experts to share their tips and tricks for solving scoped deep learning problems. -.. tip:: - - Read :doc:`here ` to understand when to use Flash vs Lightning. - ------ Tasks diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst index 482c6c9d2d..cb7620143c 100644 --- a/docs/source/reference/audio_classification.rst +++ b/docs/source/reference/audio_classification.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Audio Classification :card_description: Learn to classify audio spectrogram images with Flash and build an example classifier for the UrbanSound8k data set. - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.jpg + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg :tags: Audio,Classification .. _audio_classification: diff --git a/docs/source/reference/flash_to_pl.rst b/docs/source/reference/flash_to_pl.rst deleted file mode 100644 index 39dd990600..0000000000 --- a/docs/source/reference/flash_to_pl.rst +++ /dev/null @@ -1,121 +0,0 @@ -####################### -From Flash to Lightning -####################### - -Flash is built on top of `PyTorch Lightning -`_ to abstract away the unnecessary boilerplate for: - -- Data science -- Kaggle -- Business use cases -- Applied research - -Flash is a HIGH level library and Lightning is a LOW level library. - -- Flash (high-level) -- Lightning (medium-level) -- PyTorch (low-level) - -As the complexity increases or decreases, users can move between Flash and Lightning seamlessly to find the -level of abstraction that works for them. - -.. list-table:: Abstraction levels - :widths: 20 20 20 20 40 - :header-rows: 1 - - * - Approach - - Flexibility - - Minimum DL Expertise level - - PyTorch Knowledge - - Use cases - * - Using an out-of-the-box task - - Low - - Novice+ - - Low+ - - Fast baseline, Data Science, Analysis, Applied Research - * - Using the Generic Task - - Medium - - Intermediate+ - - Intermediate+ - - Fast baseline, data science - * - Building a custom task - - High - - Intermediate+ - - Intermediate+ - - Fast baseline, custom business context, applied research - * - Building a LightningModule - - Ultimate (organized PyTorch) - - Expert+ - - Expert+ - - For anything you can do with PyTorch, AI research (academic and corporate) - ------- - -**************************** -Using an out-of-the-box task -**************************** -Tasks can come from a variety of places: - -- Flash -- Other Lightning-based libraries -- Your own library - -Using a task requires almost zero knowledge of deep learning and PyTorch. The focus is on solving a problem as quickly as possible. -This is great for: - -- data science -- analysis -- applied research - ------- - -********************** -Using the Generic Task -********************** -If you encounter a problem that does not have a matching task, you can use the generic task. However, this does -require a bit of PyTorch knowledge but not a lot of knowledge over all the details of deep learning. - -This is great for: - -- data science -- kaggle baselines -- a quick baseline -- applied research -- learning about deep learning - -.. note:: If you've used something like Keras, this is the most similar level of abstraction. - ------- - -********************** -Building a custom task -********************** -If you're feeling adventurous and there isn't an out-of-the-box task for a particular applied problem, consider -building your own task. This requires a decent amount of PyTorch knowledge, but not too much because tasks are -LightningModules that already abstract a lot of the details for you. - -This is great for: - -- data science -- researchers building for corporate data science teams -- applied research -- custom business context - -.. note:: In a company setting, a good setup here is to have your own Flash-like library with tasks contextualized with your business problems. - ------- - -************************** -Building a LightningModule -************************** -Once you've reached the threshold of flexibility offered by Flash, it's time to move to a LightningModule directly. -LightningModule is organized PyTorch but gives you the same flexibility. However, you must already know PyTorch -fairly well and be comfortable with at least basic deep learning concepts. - -This is great for: - -- experts -- academic AI research -- corporate AI research -- advanced applied research -- publishing papers diff --git a/docs/source/reference/graph_classification.rst b/docs/source/reference/graph_classification.rst index e0b6548b05..b1b4e29fae 100644 --- a/docs/source/reference/graph_classification.rst +++ b/docs/source/reference/graph_classification.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Graph Classification :card_description: Learn to classify graphs with Flash and build an example classifier for the KKI data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg :tags: Graph,Classification .. _graph_classification: diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 8562f8ebcf..93e2983a4e 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Image Classification :card_description: Learn to classify images with Flash and build an example Ants / Bees classifier. - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.jpg + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg :tags: Image,Classification .. _image_classification: diff --git a/docs/source/reference/image_classification_multi_label.rst b/docs/source/reference/image_classification_multi_label.rst index 2151823fbf..0fa5884946 100644 --- a/docs/source/reference/image_classification_multi_label.rst +++ b/docs/source/reference/image_classification_multi_label.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Multi-label Image Classification :card_description: Learn to classify images in a multi-label setting with Flash and build an example classifier which predicts genres from movie posters. - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification_multi_label.jpg + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/multi_label_image_classification.svg :tags: Image,Multi-label,Classification .. _image_classification_multi_label: diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst index a738ea7f91..5d3abc8efb 100644 --- a/docs/source/reference/image_embedder.rst +++ b/docs/source/reference/image_embedder.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Image Embedder :card_description: Learn to generate embeddings from images with Flash. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_embedder.svg :tags: Image,Embedding .. _image_embedder: @@ -16,15 +17,24 @@ The Task Image embedding encodes an image into a vector of features which can be used for a downstream task. This could include: clustering, similarity search, or classification. +The :class:`~flash.image.embedding.model.ImageEmbedder` internally relies on `VISSL `_. + ------ ******* Example ******* -Let's see how to use the :class:`~flash.image.embedding.model.ImageEmbedder` with a pretrained backbone to obtain feature vectors from the hymenoptera data. -Once we've downloaded the data, we create the :class:`~flash.image.embedding.model.ImageEmbedder` and perform inference (obtaining feature vectors / embeddings) using :meth:`~flash.image.embedding.model.ImageEmbedder.predict`. -Here's the full example: +Let's see how to configure a training strategy for the :class:`~flash.image.embedding.model.ImageEmbedder` task. +A vanilla :class:`~flash.core.data.data_module.DataModule` object be created using standard Datasets as shown below. +Then the user can configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``. +There are options provided to send additional arguments to config selections. +This task can now be sent to the ``fit()`` method of :class:`~flash.core.trainer.Trainer`. + +.. note:: + + A lot of VISSL loss functions use hard-coded ``torch.distributed`` methods. The user is suggested to use ``accelerator=ddp`` even with a single GPU. + Only ``barlow_twins`` training strategy works on the CPU. All other loss functions are configured to work on GPUs. .. literalinclude:: ../../../flash_examples/image_embedder.py :language: python diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst index 5a05d45014..1a0f81503f 100644 --- a/docs/source/reference/instance_segmentation.rst +++ b/docs/source/reference/instance_segmentation.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Instance Segmentation :card_description: Learn to segment objects in images with Flash and build a model for segmenting images of pets. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/instance_segmentation.svg :tags: Image,Segmentation,Detection .. _instance_segmentation: diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst index 6ce0610d0e..53cd0e9035 100644 --- a/docs/source/reference/keypoint_detection.rst +++ b/docs/source/reference/keypoint_detection.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Keypoint Detection :card_description: Learn to detect keypoints in images with Flash and build a network to detect facial keypoints with the BIWI data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/keypoint_detection.svg :tags: Image,Keypoint,Detection .. _keypoint_detection: diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 6db6b85272..dd70e4bfed 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Object Detection :card_description: Learn to detect objects in images with Flash and build an example detector with the COCO data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg :tags: Image,Detection .. _object_detection: @@ -72,3 +73,35 @@ To view configuration options and options for running the object detector with y .. code-block:: bash flash object_detection --help + +------ + +********************** +Custom Transformations +********************** + +Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case. +The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline. +For object-detection tasks, you can leverage the transformations from `Albumentations `__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`. + +.. code-block:: python + + import albumentations as alb + from icevision.tfms import A + + from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter + from flash.image import ObjectDetectionData + + train_transform = { + "pre_tensor_transform": transforms.IceVisionTransformAdapter( + [*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()] + ) + } + + datamodule = ObjectDetectionData.from_coco( + train_folder="data/coco128/images/train2017/", + train_ann_file="data/coco128/annotations/instances_train2017.json", + val_split=0.1, + image_size=128, + train_transform=train_transform, + ) diff --git a/docs/source/reference/pointcloud_object_detection.rst b/docs/source/reference/pointcloud_object_detection.rst index 94cb6327d5..0bdccf0ee5 100644 --- a/docs/source/reference/pointcloud_object_detection.rst +++ b/docs/source/reference/pointcloud_object_detection.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Point Cloud Object Detection :card_description: Learn to detect objects in point clouds with Flash and build an example detector with the KITTI data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/point_cloud_object_detection.svg :tags: Point-Cloud,Detection .. _pointcloud_object_detection: diff --git a/docs/source/reference/pointcloud_segmentation.rst b/docs/source/reference/pointcloud_segmentation.rst index f02aeb6cbc..28a0c31623 100644 --- a/docs/source/reference/pointcloud_segmentation.rst +++ b/docs/source/reference/pointcloud_segmentation.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Point Cloud Segmentation :card_description: Learn to segment objects in point clouds with Flash and build an example network with the KITTI data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/point_cloud_segmentation.svg :tags: Point-Cloud,Segmentation .. _pointcloud_segmentation: diff --git a/docs/source/reference/question_answering.rst b/docs/source/reference/question_answering.rst index 3030840b82..b264b83823 100644 --- a/docs/source/reference/question_answering.rst +++ b/docs/source/reference/question_answering.rst @@ -1,3 +1,9 @@ +.. customcarditem:: + :header: Extractive Question Answering + :card_description: Learn to answer questions pertaining to some known textual context. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :tags: NLP,Text + .. _question_answering: ################## diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 1131ca55ed..c88c67c8f8 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Semantic Segmentation :card_description: Learn to segment images with Flash and build a model which segments images from the CARLA driving simulator. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/semantic_segmentation.svg :tags: Image,Segmentation .. _semantic_segmentation: diff --git a/docs/source/reference/speech_recognition.rst b/docs/source/reference/speech_recognition.rst index 35677f3a01..56ec4be145 100644 --- a/docs/source/reference/speech_recognition.rst +++ b/docs/source/reference/speech_recognition.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Speech Recognition :card_description: Learn to recognize speech Flash (speech-to-text) and train a model on the TIMIT corpus. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/speech_recognition.svg :tags: Audio,Speech-Recognition,NLP .. _speech_recognition: diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index fae2a4d22a..15f23c337e 100644 --- a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Style Transfer :card_description: Learn about image style transfer with Flash and build an example which transfers style from The Starry Night to images from the COCO data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/style_transfer.svg :tags: Image,Style-Transfer .. _style_transfer: diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst index ed4cb3bbf0..e402957f75 100644 --- a/docs/source/reference/summarization.rst +++ b/docs/source/reference/summarization.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Summarization :card_description: Learn to summarize long passages of text with Flash and build an example model with the XSUM data set. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg :tags: Text,Summarization,NLP .. _summarization: diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index eb33db82fc..b85939ce83 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Tabular Classification :card_description: Learn to classify tabular records with Flash and build an example model to predict survival rates on the Titanic. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg :tags: Tabular,Classification .. _tabular_classification: diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index e2142819b3..53f947b47a 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Text Classification :card_description: Learn to classify text with Flash and build an example sentiment analyser for IMDB reviews. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg :tags: Text,Classification,NLP .. _text_classification: diff --git a/docs/source/reference/text_classification_multi_label.rst b/docs/source/reference/text_classification_multi_label.rst index 468d2b4f7d..0ef245ceab 100644 --- a/docs/source/reference/text_classification_multi_label.rst +++ b/docs/source/reference/text_classification_multi_label.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Multi-label Text Classification :card_description: Learn to classify text in a multi-label setting with Flash and build an example comment toxicity classifier. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/multi_label_text_classification.svg :tags: Text,Multi-label,Classification,NLP .. _text_classification_multi_label: diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index bc37ad67eb..9b0c67c802 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Translation :card_description: Learn to translate text with Flash and build an example model which translates from English to Romanian. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/translation.svg :tags: Text,Translation,NLP .. _translation: diff --git a/docs/source/reference/video_classification.rst b/docs/source/reference/video_classification.rst index 0e1145cc28..bbcb15ad84 100644 --- a/docs/source/reference/video_classification.rst +++ b/docs/source/reference/video_classification.rst @@ -1,6 +1,7 @@ .. customcarditem:: :header: Video Classification :card_description: Learn to classify videos with Flash and build an example action classifier. + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/video_classification.svg :tags: Video,Classification .. _video_classification: diff --git a/flash/__about__.py b/flash/__about__.py index 02e7453491..0ae18a3a6e 100644 --- a/flash/__about__.py +++ b/flash/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.0" +__version__ = "0.5.1dev" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/flash/audio/classification/cli.py b/flash/audio/classification/cli.py index c198a99239..c69b1e540c 100644 --- a/flash/audio/classification/cli.py +++ b/flash/audio/classification/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.audio import AudioClassificationData from flash.core.data.utils import download_data @@ -23,7 +22,7 @@ def from_urban8k( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> AudioClassificationData: """Downloads and loads the Urban 8k sounds images data set.""" diff --git a/flash/audio/speech_recognition/cli.py b/flash/audio/speech_recognition/cli.py index 9bbdb48df8..f8a7ad26dd 100644 --- a/flash/audio/speech_recognition/cli.py +++ b/flash/audio/speech_recognition/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.audio import SpeechRecognition, SpeechRecognitionData from flash.core.data.utils import download_data @@ -23,7 +22,7 @@ def from_timit( val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> SpeechRecognitionData: """Downloads and loads the timit data set.""" diff --git a/flash/core/adapter.py b/flash/core/adapter.py index c7557b1977..a9a0b84f99 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -14,6 +14,7 @@ from abc import abstractmethod from typing import Any, Callable, Optional +import torch.jit from torch import nn from torch.utils.data import DataLoader, Sampler @@ -59,6 +60,10 @@ def test_epoch_end(self, outputs) -> None: pass +def identity_collate_fn(x): + return x + + class AdapterTask(Task): """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter` and forwards all of the hooks. @@ -73,11 +78,12 @@ def __init__(self, adapter: Adapter, **kwargs): self.adapter = adapter + @torch.jit.unused @property def backbone(self) -> nn.Module: return self.adapter.backbone - def forward(self, x: Any) -> Any: + def forward(self, x: torch.Tensor) -> Any: return self.adapter.forward(x) def training_step(self, batch: Any, batch_idx: int) -> Any: @@ -104,6 +110,7 @@ def test_epoch_end(self, outputs) -> None: def process_train_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -113,12 +120,21 @@ def process_train_dataset( sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_train_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + dataset, + trainer, + batch_size, + num_workers, + pin_memory, + collate_fn, + shuffle, + drop_last, + sampler, ) def process_val_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -128,12 +144,21 @@ def process_val_dataset( sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_val_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + dataset, + trainer, + batch_size, + num_workers, + pin_memory, + collate_fn, + shuffle, + drop_last, + sampler, ) def process_test_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -143,7 +168,15 @@ def process_test_dataset( sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_test_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + dataset, + trainer, + batch_size, + num_workers, + pin_memory, + collate_fn, + shuffle, + drop_last, + sampler, ) def process_predict_dataset( @@ -152,11 +185,18 @@ def process_predict_dataset( batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, - collate_fn: Callable = lambda x: x, + collate_fn: Callable = identity_collate_fn, shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, ) -> DataLoader: return self.adapter.process_predict_dataset( - dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, ) diff --git a/flash/core/classification.py b/flash/core/classification.py index b11e714528..5dacef2bb8 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -18,6 +18,7 @@ import torchmetrics from pytorch_lightning.utilities import rank_zero_warn +from flash.core.adapter import AdapterTask from flash.core.data.data_source import DefaultDataKeys, LabelsState from flash.core.data.process import Serializer from flash.core.model import Task @@ -37,7 +38,29 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. return F.binary_cross_entropy_with_logits(x, y.float()) -class ClassificationTask(Task): +class ClassificationMixin: + def _build( + self, + num_classes: Optional[int] = None, + loss_fn: Optional[Callable] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + multi_label: bool = False, + ): + if metrics is None: + metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy() + + if loss_fn is None: + loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy + + return metrics, loss_fn + + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + if getattr(self.hparams, "multi_label", False): + return torch.sigmoid(x) + return torch.softmax(x, dim=1) + + +class ClassificationTask(Task, ClassificationMixin): def __init__( self, *args, @@ -48,11 +71,9 @@ def __init__( serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs, ) -> None: - if metrics is None: - metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy() - if loss_fn is None: - loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy + metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label) + super().__init__( *args, loss_fn=loss_fn, @@ -61,11 +82,28 @@ def __init__( **kwargs, ) - def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: - if getattr(self.hparams, "multi_label", False): - return torch.sigmoid(x) - # we'll assume that the data always comes as `(B, C, ...)` - return torch.softmax(x, dim=1) + +class ClassificationAdapterTask(AdapterTask, ClassificationMixin): + def __init__( + self, + *args, + num_classes: Optional[int] = None, + loss_fn: Optional[Callable] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + multi_label: bool = False, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + **kwargs, + ) -> None: + + metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label) + + super().__init__( + *args, + loss_fn=loss_fn, + metrics=metrics, + serializer=serializer or Classes(multi_label=multi_label), + **kwargs, + ) class ClassificationSerializer(Serializer): diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 7f94c039b7..054efbfb6d 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import platform from typing import ( Any, Callable, @@ -98,7 +97,7 @@ def __init__( data_fetcher: Optional[BaseDataFetcher] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, ) -> None: @@ -138,13 +137,10 @@ def __init__( self.batch_size = batch_size - # TODO: figure out best solution for setting num_workers if num_workers is None: - if platform.system() in ("Darwin", "Windows"): - num_workers = 0 - else: - num_workers = os.cpu_count() + num_workers = 0 self.num_workers = num_workers + self.sampler = sampler self.set_running_stages() @@ -302,6 +298,7 @@ def _train_dataloader(self) -> DataLoader: if isinstance(getattr(self, "trainer", None), pl.Trainer): return self.trainer.lightning_module.process_train_dataset( train_ds, + trainer=self.trainer, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, @@ -330,6 +327,7 @@ def _val_dataloader(self) -> DataLoader: if isinstance(getattr(self, "trainer", None), pl.Trainer): return self.trainer.lightning_module.process_val_dataset( val_ds, + trainer=self.trainer, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, @@ -352,6 +350,7 @@ def _test_dataloader(self) -> DataLoader: if isinstance(getattr(self, "trainer", None), pl.Trainer): return self.trainer.lightning_module.process_test_dataset( test_ds, + trainer=self.trainer, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=pin_memory, @@ -368,6 +367,7 @@ def _test_dataloader(self) -> DataLoader: def _predict_dataloader(self) -> DataLoader: predict_ds: Dataset = self._predict_ds() if isinstance(self._predict_ds, Callable) else self._predict_ds + if isinstance(predict_ds, IterableAutoDataset): batch_size = self.batch_size else: @@ -468,7 +468,7 @@ def from_data_source( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -569,7 +569,7 @@ def from_folders( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -643,7 +643,7 @@ def from_files( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -720,7 +720,7 @@ def from_tensors( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -807,7 +807,7 @@ def from_numpy( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -893,7 +893,7 @@ def from_json( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, field: Optional[str] = None, **preprocess_kwargs: Any, @@ -1003,7 +1003,7 @@ def from_csv( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -1087,7 +1087,7 @@ def from_datasets( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -1168,7 +1168,7 @@ def from_fiftyone( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 41ff53e8be..cd0a16fada 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -534,7 +534,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin if isinstance(dl_args["collate_fn"], _Preprocessor): dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn - if isinstance(dl_args["dataset"], IterableAutoDataset): + if isinstance(dl_args["dataset"], (IterableAutoDataset, IterableDataset)): del dl_args["sampler"] del dl_args["batch_sampler"] diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index e95e29ec2d..fb4260ed89 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -465,6 +465,9 @@ def load_data( data = make_dataset(data, class_to_idx, extensions=self.extensions) return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] + elif dataset is not None: + dataset.num_classes = len(np.unique(data[1])) + return list( filter( lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), @@ -599,13 +602,12 @@ def load_data( } for _, row in data_frame.iterrows() ] - else: - return [ - { - DefaultDataKeys.INPUT: row[input_key], - } - for _, row in data_frame.iterrows() - ] + return [ + { + DefaultDataKeys.INPUT: row[input_key], + } + for _, row in data_frame.iterrows() + ] def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: # TODO: simplify this duplicated code from PathsDataSource @@ -624,6 +626,16 @@ class TensorDataSource(SequenceDataSource[torch.Tensor]): """The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to :meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of ``torch.Tensor`` objects.""" + def load_data( + self, + data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], + dataset: Optional[Any] = None, + ) -> Sequence[Mapping[str, Any]]: + # TODO: Bring back the code to work out how many classes there are + if len(data) == 2: + dataset.num_classes = len(torch.unique(torch.tensor(data[1]))) + return super().load_data(data, dataset) + class NumpyDataSource(SequenceDataSource[np.ndarray]): """The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 5ebb4d15b0..3b4a8d901c 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -342,17 +342,22 @@ def default_transforms() -> Optional[Dict[str, Callable]]: """ return None + def _apply_sample_transform(self, sample: Any) -> Any: + if isinstance(sample, list): + return [self.current_transform(s) for s in sample] + return self.current_transform(sample) + def pre_tensor_transform(self, sample: Any) -> Any: """Transforms to apply on a single object.""" - return self.current_transform(sample) + return self._apply_sample_transform(sample) def to_tensor_transform(self, sample: Any) -> Tensor: """Transforms to convert single object to a tensor.""" - return self.current_transform(sample) + return self._apply_sample_transform(sample) def post_tensor_transform(self, sample: Tensor) -> Tensor: """Transforms to apply on a tensor.""" - return self.current_transform(sample) + return self._apply_sample_transform(sample) def per_batch_transform(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index aad996fdfe..42a5d40fcb 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -111,6 +111,8 @@ def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]: This function removes that dimension and then applies ``torch.utils.data._utils.collate.default_collate``. """ + if len(samples) == 1 and isinstance(samples[0], list): + samples = samples[0] for sample in samples: for key in sample.keys(): if torch.is_tensor(sample[key]) and sample[key].ndim == 4: diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 83be7c3848..1e6c7d48a9 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -16,6 +16,7 @@ from torch.utils.data import DataLoader, Sampler +import flash from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_source import DefaultDataKeys @@ -91,6 +92,7 @@ def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = def process_train_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -114,6 +116,7 @@ def process_train_dataset( def process_val_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -137,6 +140,7 @@ def process_val_dataset( def process_test_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index ee1dfe1ed5..b91fa0dcca 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -68,8 +68,7 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) records = parser.parse(data_splitter=SingleSplitSplitter()) return [{DefaultDataKeys.INPUT: record} for record in records[0]] - else: - raise ValueError("The parser argument must be provided.") + raise ValueError("The parser argument must be provided.") def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: result = super().predict_load_data(data, dataset) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 3859bfa2ff..5619dfd5af 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -196,7 +196,13 @@ def from_icevision_predictions(predictions: List["Prediction"]): class IceVisionTransformAdapter(nn.Module): - def __init__(self, transform): + """ + Args: + transform: list of transformation functions to apply + + """ + + def __init__(self, transform: List[Callable]): super().__init__() self.transform = A.Adapter(transform) diff --git a/flash/core/integrations/vissl/transforms/utilities.py b/flash/core/integrations/vissl/transforms/utilities.py deleted file mode 100644 index 3590011947..0000000000 --- a/flash/core/integrations/vissl/transforms/utilities.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from flash.core.data.data_source import DefaultDataKeys - - -def vissl_collate_fn(samples): - """Custom collate function for VISSL integration. - - Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT - """ - result = [] - - for batch_ele in samples: - _batch_ele_dict = {} - _batch_ele_dict.update(batch_ele) - _batch_ele_dict[DefaultDataKeys.INPUT] = -1 - - result.append(_batch_ele_dict) - - result = torch.utils.data._utils.collate.default_collate(result) - - inputs = [[] for _ in range(len(samples[0][DefaultDataKeys.INPUT]))] - for batch_ele in samples: - multi_crop_imgs = batch_ele[DefaultDataKeys.INPUT] - - for idx, crop in enumerate(multi_crop_imgs): - inputs[idx].append(crop) - - for idx, ele in enumerate(inputs): - inputs[idx] = torch.stack(ele) - - result[DefaultDataKeys.INPUT] = inputs - - return result diff --git a/flash/core/model.py b/flash/core/model.py index e555b86047..891bd5612a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -26,6 +26,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -68,7 +69,7 @@ def __init__(self): self._children = [] # TODO: create enum values to define what are the exact states - self._data_pipeline_state: Optional[DataPipelineState] = None + self._data_pipeline_state: DataPipelineState = DataPipelineState() # model own internal state shared with the data pipeline. self._state: Dict[Type[ProcessState], ProcessState] = {} @@ -118,6 +119,7 @@ def _process_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + persistent_workers: bool = True, ) -> DataLoader: return DataLoader( dataset, @@ -128,11 +130,13 @@ def _process_dataset( drop_last=drop_last, sampler=sampler, collate_fn=collate_fn, + persistent_workers=persistent_workers, ) def process_train_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -150,11 +154,13 @@ def process_train_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=num_workers > 0, ) def process_val_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -172,11 +178,13 @@ def process_val_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=num_workers > 0, ) def process_test_dataset( self, dataset: BaseAutoDataset, + trainer: "flash.Trainer", batch_size: int, num_workers: int, pin_memory: bool, @@ -194,6 +202,7 @@ def process_test_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=num_workers > 0, ) def process_predict_dataset( @@ -216,6 +225,7 @@ def process_predict_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=False, ) @@ -267,6 +277,19 @@ def __new__(mcs, *args, **kwargs): return result +class OutputKeys(LightningEnum): + """The ``OutputKeys`` enum contains the keys that are used internally by the ``Task`` when handling outputs.""" + + OUTPUT = "y_hat" + TARGET = "y" + LOGS = "logs" + LOSS = "loss" + + # TODO: Create a FlashEnum class??? + def __hash__(self) -> int: + return hash(self.value) + + class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=CheckDependenciesMeta): """A general Task. @@ -331,11 +354,11 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: x, y = batch y_hat = self(x) y, y_hat = self.apply_filtering(y, y_hat) - output = {"y_hat": y_hat} - y_hat = self.to_loss_format(output["y_hat"]) + output = {OutputKeys.OUTPUT: y_hat} + y_hat = self.to_loss_format(output[OutputKeys.OUTPUT]) losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} - y_hat = self.to_metrics_format(output["y_hat"]) + y_hat = self.to_metrics_format(output[OutputKeys.OUTPUT]) logs = {} @@ -350,9 +373,9 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: logs["total_loss"] = sum(losses.values()) return logs["total_loss"], logs - output["loss"] = self.compute_loss(losses) - output["logs"] = self.compute_logs(logs, losses) - output["y"] = y + output[OutputKeys.LOSS] = self.compute_loss(losses) + output[OutputKeys.LOGS] = self.compute_logs(logs, losses) + output[OutputKeys.TARGET] = y return output def compute_loss(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor: @@ -380,16 +403,31 @@ def forward(self, x: Any) -> Any: def training_step(self, batch: Any, batch_idx: int) -> Any: output = self.step(batch, batch_idx, self.train_metrics) - self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True) - return output["loss"] + self.log_dict( + {f"train_{k}": v for k, v in output[OutputKeys.LOGS].items()}, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + return output[OutputKeys.LOSS] def validation_step(self, batch: Any, batch_idx: int) -> None: output = self.step(batch, batch_idx, self.val_metrics) - self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) + self.log_dict( + {f"val_{k}": v for k, v in output[OutputKeys.LOGS].items()}, + on_step=False, + on_epoch=True, + prog_bar=True, + ) def test_step(self, batch: Any, batch_idx: int) -> None: output = self.step(batch, batch_idx, self.val_metrics) - self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) + self.log_dict( + {f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()}, + on_step=False, + on_epoch=True, + prog_bar=True, + ) @predict_context def predict( @@ -422,6 +460,7 @@ def predict( else: x = self.transfer_batch_to_device(x, self.device) x = data_pipeline.device_preprocessor(running_stage)(x) + x = x[0] if isinstance(x, list) else x predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict` predictions = data_pipeline.postprocessor(running_stage)(predictions) return predictions diff --git a/flash/core/optimizers/lr_scheduler.py b/flash/core/optimizers/lr_scheduler.py index 187f6c495f..5cf3fcc278 100644 --- a/flash/core/optimizers/lr_scheduler.py +++ b/flash/core/optimizers/lr_scheduler.py @@ -93,14 +93,14 @@ def get_lr(self) -> List[float]: if self.last_epoch == self.warmup_epochs: return self.base_lrs - elif self.last_epoch == 0: + if self.last_epoch == 0: return [self.warmup_start_lr] * len(self.base_lrs) - elif self.last_epoch < self.warmup_epochs: + if self.last_epoch < self.warmup_epochs: return [ group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] - elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 diff --git a/flash/core/registry.py b/flash/core/registry.py index 714b2a3537..a454948e04 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -88,7 +88,7 @@ def get( """ matches = [e for e in self.functions if key == e["name"]] if not matches: - raise KeyError(f"Key: {key} is not in {type(self).__name__}") + raise KeyError(f"Key: {key} is not in {type(self).__name__}. Available keys: {self.available_keys()}") if metadata: matches = [m for m in matches if metadata.items() <= m["metadata"].items()] @@ -111,7 +111,11 @@ def _register_function( if not callable(fn): raise MisconfigurationException(f"You can only register a callable, found: {fn}") - name = name or fn.__name__ + if name is None: + if hasattr(fn, "func"): + name = fn.func.__name__ + else: + name = fn.__name__ if self._verbose: rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}") diff --git a/flash/core/trainer.py b/flash/core/trainer.py index eeb030db11..02adfeae85 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -21,6 +21,7 @@ from pytorch_lightning import LightningDataModule, LightningModule from pytorch_lightning import Trainer as PlTrainer from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -101,6 +102,28 @@ def run_sanity_check(self, ref_model): if self.serve_sanity_check and ref_model.is_servable and _SERVE_AVAILABLE: ref_model.run_serve_sanity_check() + # TODO @(tchaton) remove `reset_train_val_dataloaders` from run_train function + def _run_train(self) -> None: + self._pre_training_routine() + + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() + + self._run_sanity_check(self.lightning_module) + + # enable train mode + self.model.train() + torch.set_grad_enabled(True) + + # reload data when needed + model = self.lightning_module + + if isinstance(self.fit_loop, FitLoop): + self.reset_train_val_dataloaders(model) + + self.fit_loop.trainer = self + self.fit_loop.run() + def fit( self, model: LightningModule, diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 88a1840830..f138eaf37e 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -17,7 +17,6 @@ import types from importlib.util import find_spec from typing import List, Union -from warnings import warn from pkg_resources import DistributionNotFound @@ -98,26 +97,18 @@ def _compare_version(package: str, op, version) -> bool: _DATASETS_AVAILABLE = _module_available("datasets") _ICEVISION_AVAILABLE = _module_available("icevision") _ICEDATA_AVAILABLE = _module_available("icedata") +_LEARN2LEARN_AVAILABLE = _module_available("learn2learn") and _compare_version("learn2learn", operator.ge, "0.1.6") _TORCH_ORT_AVAILABLE = _module_available("torch_ort") _VISSL_AVAILABLE = _module_available("vissl") and _module_available("classy_vision") +_ALBUMENTATIONS_AVAILABLE = _module_available("albumentations") +_BAAL_AVAILABLE = _module_available("baal") if _PIL_AVAILABLE: - from PIL import Image + from PIL import Image # noqa: F401 else: - class MetaImage(type): - def __init__(cls, name, bases, dct): - super().__init__(name, bases, dct) - - cls._Image = None - - @property - def Image(cls): - warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".") - return cls._Image - - class Image(metaclass=MetaImage): - pass + class Image: + Image = object if Version: @@ -191,8 +182,7 @@ def wrapper(*args, **kwargs): ) return wrapper - else: - return func + return func return decorator diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index f25c402683..a5bb749246 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -39,8 +39,10 @@ def __str__(self): _SEGMENTATION_MODELS = Provider( "qubvel/segmentation_models.pytorch", "https://github.com/qubvel/segmentation_models.pytorch" ) +_LEARN2LEARN = Provider("learnables/learn2learn", "https://github.com/learnables/learn2learn") _PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche") _HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers") _FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq") _OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML") _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") +_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl") diff --git a/flash/graph/classification/cli.py b/flash/graph/classification/cli.py index f79af259d8..d8fd18702c 100644 --- a/flash/graph/classification/cli.py +++ b/flash/graph/classification/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.utilities.flash_cli import FlashCLI from flash.graph import GraphClassificationData, GraphClassifier @@ -23,7 +22,7 @@ def from_tu_dataset( name: str = "KKI", val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> GraphClassificationData: """Downloads and loads the TU Dataset.""" diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py new file mode 100644 index 0000000000..a4b20a283e --- /dev/null +++ b/flash/image/classification/adapters.py @@ -0,0 +1,544 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from collections import defaultdict +from functools import partial +from typing import Any, Callable, List, Optional, Type + +import torch +from pytorch_lightning import LightningModule +from pytorch_lightning.plugins import DataParallelPlugin, DDPPlugin, DDPSpawnPlugin +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import WarningCache +from torch.utils.data import DataLoader, IterableDataset, Sampler + +import flash +from flash.core.adapter import Adapter, AdapterTask +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import DefaultDataKeys +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE +from flash.core.utilities.providers import _LEARN2LEARN +from flash.core.utilities.url_error import catch_url_error +from flash.image.classification.integrations.learn2learn import TaskDataParallel, TaskDistributedDataParallel + +warning_cache = WarningCache() + + +if _LEARN2LEARN_AVAILABLE: + import learn2learn as l2l + from learn2learn.data.transforms import RemapLabels as Learn2LearnRemapLabels +else: + + class Learn2LearnRemapLabels: + pass + + +class RemapLabels(Learn2LearnRemapLabels): + def remap(self, data, mapping): + # remap needs to be adapted to Flash API. + data[DefaultDataKeys.TARGET] = mapping(data[DefaultDataKeys.TARGET]) + return data + + +class NoModule: + + """This class is used to prevent nn.Module infinite recursion.""" + + def __init__(self, task): + self.task = task + + def __getattr__(self, key): + if key != "task": + return getattr(self.task, key) + return self.task + + def __setattr__(self, key: str, value: Any) -> None: + if key == "task": + object.__setattr__(self, key, value) + return + setattr(self.task, key, value) + + +class Model(torch.nn.Module): + def __init__(self, backbone: torch.nn.Module, head: Optional[torch.nn.Module]): + super().__init__() + self.backbone = backbone + self.head = head + + def forward(self, x): + x = self.backbone(x) + if x.dim() == 4: + x = x.mean(-1).mean(-1) + if self.head is None: + return x + return self.head(x) + + +class Learn2LearnAdapter(Adapter): + + required_extras: str = "image" + + def __init__( + self, + task: AdapterTask, + backbone: torch.nn.Module, + head: torch.nn.Module, + algorithm_cls: Type[LightningModule], + ways: int, + shots: int, + meta_batch_size: int, + queries: int = 1, + num_task: int = -1, + epoch_length: Optional[int] = None, + test_epoch_length: Optional[int] = None, + test_ways: Optional[int] = None, + test_shots: Optional[int] = None, + test_queries: Optional[int] = None, + test_num_task: Optional[int] = None, + default_transforms_fn: Optional[Callable] = None, + seed: int = 42, + **algorithm_kwargs, + ): + """The ``Learn2LearnAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with `learn 2 + learn` library (https://github.com/learnables/learn2learn). + + Args: + task: Task to be used. This adapter should work with any Flash Classification task + backbone: Feature extractor to be used. + head: Predictive head. + algorithm_cls: Algorithm class coming + from: https://github.com/learnables/learn2learn/tree/master/learn2learn/algorithms/lightning + ways: Number of classes conserved for generating the task. + shots: Number of samples used for adaptation. + meta_batch_size: Number of task to be sampled and optimized over before doing a meta optimizer step. + queries: Number of samples used for computing the meta loss after the adaption on the `shots` samples. + num_task: Total number of tasks to be sampled during training. If -1, a new task will always be sampled. + epoch_length: Total number of tasks to be sampled to make an epoch. + test_ways: Number of classes conserved for generating the validation and testing task. + test_shots: Number of samples used for adaptation during validation and testing phase. + test_queries: Number of samples used for computing the meta loss during validation or testing + after the adaption on `shots` samples. + epoch_length: Total number of tasks to be sampled to make an epoch during validation and testing phase. + default_transforms_fn: A Callable to create the task transform. + The callable should take the dataset, ways and shots as arguments. + algorithm_kwargs: Keyword arguments to be provided to the algorithm class from learn2learn + """ + + super().__init__() + + self._task = NoModule(task) + self.backbone = backbone + self.head = head + self.algorithm_cls = algorithm_cls + self.meta_batch_size = meta_batch_size + + self.num_task = num_task + self.default_transforms_fn = default_transforms_fn + self.seed = seed + self.epoch_length = epoch_length or meta_batch_size + + self.ways = ways + self.shots = shots + self.queries = queries + + self.test_ways = test_ways or ways + self.test_shots = test_shots or shots + self.test_queries = test_queries or queries + self.test_num_task = test_num_task or num_task + self.test_epoch_length = test_epoch_length or self.epoch_length + + params = inspect.signature(self.algorithm_cls).parameters + + algorithm_kwargs["train_ways"] = ways + algorithm_kwargs["train_shots"] = shots + algorithm_kwargs["train_queries"] = queries + + algorithm_kwargs["test_ways"] = self.test_ways + algorithm_kwargs["test_shots"] = self.test_shots + algorithm_kwargs["test_queries"] = self.test_queries + + if "model" in params: + algorithm_kwargs["model"] = Model(backbone=backbone, head=head) + + if "features" in params: + algorithm_kwargs["features"] = Model(backbone=backbone, head=None) + + if "classifier" in params: + algorithm_kwargs["classifier"] = head + + self.model = self.algorithm_cls(**algorithm_kwargs) + + # this algorithm requires a special treatment + self._algorithm_has_validated = self.algorithm_cls != l2l.algorithms.LightningPrototypicalNetworks + + def _default_transform(self, dataset, ways: int, shots: int, queries) -> List[Callable]: + return [ + l2l.data.transforms.FusedNWaysKShots(dataset, n=ways, k=shots + queries), + l2l.data.transforms.LoadData(dataset), + RemapLabels(dataset), + l2l.data.transforms.ConsecutiveLabels(dataset), + ] + + @staticmethod + def _labels_to_indices(data): + out = defaultdict(list) + for idx, sample in enumerate(data): + label = sample[DefaultDataKeys.TARGET] + if torch.is_tensor(label): + label = label.item() + out[label].append(idx) + return out + + def _convert_dataset( + self, + trainer: "flash.Trainer", + dataset: BaseAutoDataset, + ways: int, + shots: int, + queries: int, + num_workers: int, + num_task: int, + epoch_length: int, + ): + if isinstance(dataset, BaseAutoDataset): + + metadata = getattr(dataset, "data", None) + if metadata is None or (metadata is not None and not isinstance(dataset.data, list)): + raise MisconfigurationException("Only dataset built out of metadata is supported.") + + labels_to_indices = self._labels_to_indices(dataset.data) + + if len(labels_to_indices) < ways: + raise MisconfigurationException( + "Provided `ways` should be lower or equal to number of classes within your dataset." + ) + + if min(len(indice) for indice in labels_to_indices.values()) < (shots + queries): + raise MisconfigurationException( + "Provided `shots + queries` should be lower than the lowest number of sample per class." + ) + + # convert the dataset to MetaDataset + dataset = l2l.data.MetaDataset(dataset, indices_to_labels=None, labels_to_indices=labels_to_indices) + + transform_fn = self.default_transforms_fn or self._default_transform + + taskset = l2l.data.TaskDataset( + dataset=dataset, + task_transforms=transform_fn(dataset, ways=ways, shots=shots, queries=queries), + num_tasks=num_task, + task_collate=self._identity_task_collate_fn, + ) + + if isinstance( + trainer.training_type_plugin, + ( + DDPPlugin, + DDPSpawnPlugin, + ), + ): + # when running in a distributed data parallel way, + # we are actually sampling one task per device. + dataset = TaskDistributedDataParallel( + taskset=taskset, + global_rank=trainer.global_rank, + world_size=trainer.world_size, + num_workers=num_workers, + epoch_length=epoch_length, + seed=os.getenv("PL_GLOBAL_SEED", self.seed), + requires_divisible=trainer.training, + ) + self.trainer.accumulated_grad_batches = self.meta_batch_size / trainer.world_size + else: + devices = 1 + if isinstance(trainer.training_type_plugin, DataParallelPlugin): + # when using DP, we need to sample n tasks, so it can splitted across multiple devices. + devices = trainer.accelerator_connector.devices + dataset = TaskDataParallel(taskset, epoch_length=epoch_length, devices=devices, collate_fn=None) + self.trainer.accumulated_grad_batches = self.meta_batch_size / devices + + return dataset + + @staticmethod + def _identity_task_collate_fn(x: Any) -> Any: + return x + + @classmethod + @catch_url_error + def from_task( + cls, + *args, + task: AdapterTask, + backbone: torch.nn.Module, + head: torch.nn.Module, + algorithm: Type[LightningModule], + **kwargs, + ) -> Adapter: + if "meta_batch_size" not in kwargs: + raise MisconfigurationException( + "The `meta_batch_size` should be provided as training_strategy_kwargs={'meta_batch_size'=...}. " + "This is equivalent to the epoch length." + ) + if "shots" not in kwargs: + raise MisconfigurationException( + "The `shots` should be provided training_strategy_kwargs={'shots'=...}. " + "This is equivalent to the number of sample per label to select within a task." + ) + return cls(task, backbone, head, algorithm, **kwargs) + + def training_step(self, batch, batch_idx) -> Any: + input = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.model.training_step(input, batch_idx) + + def validation_step(self, batch, batch_idx): + # Should be True only for trainer.validate + if self.trainer.state.fn == TrainerFn.VALIDATING: + self._algorithm_has_validated = True + input = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.model.validation_step(input, batch_idx) + + def validation_epoch_end(self, outpus: Any): + self.model.validation_epoch_end(outpus) + + def test_step(self, batch, batch_idx): + input = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.model.test_step(input, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self.model.predict_step(batch[DefaultDataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx) + + def _sanetize_batch_size(self, batch_size: int) -> int: + if batch_size != 1: + warning_cache.warn( + "When using a meta-learning training_strategy, the batch_size should be set to 1. " + "HINT: You can modify the `meta_batch_size` to 100 for example by doing " + f"{type(self._task.task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})" + ) + return 1 + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + trainer: "flash.Trainer", + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + dataset = self._convert_dataset( + trainer=trainer, + dataset=dataset, + ways=self.ways, + shots=self.shots, + queries=self.queries, + num_workers=num_workers, + num_task=self.num_task, + epoch_length=self.epoch_length, + ) + if isinstance(dataset, IterableDataset): + shuffle = False + sampler = None + return super().process_train_dataset( + dataset, + trainer, + self._sanetize_batch_size(batch_size), + num_workers, + False, + collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + trainer: "flash.Trainer", + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + dataset = self._convert_dataset( + trainer=trainer, + dataset=dataset, + ways=self.test_ways, + shots=self.test_shots, + queries=self.test_queries, + num_workers=num_workers, + num_task=self.test_num_task, + epoch_length=self.test_epoch_length, + ) + if isinstance(dataset, IterableDataset): + shuffle = False + sampler = None + return super().process_train_dataset( + dataset, + trainer, + self._sanetize_batch_size(batch_size), + num_workers, + False, + collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + trainer: "flash.Trainer", + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + dataset = self._convert_dataset( + trainer=trainer, + dataset=dataset, + ways=self.test_ways, + shots=self.test_shots, + queries=self.test_queries, + num_workers=num_workers, + num_task=self.test_num_task, + epoch_length=self.test_epoch_length, + ) + if isinstance(dataset, IterableDataset): + shuffle = False + sampler = None + return super().process_train_dataset( + dataset, + trainer, + self._sanetize_batch_size(batch_size), + num_workers, + False, + collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + + if not self._algorithm_has_validated: + raise MisconfigurationException( + "This training_strategies requires to be validated. Call trainer.validate(...)." + ) + + return super().process_predict_dataset( + dataset, + batch_size, + num_workers, + pin_memory, + collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + +class DefaultAdapter(Adapter): + """The ``DefaultAdapter`` is an :class:`~flash.core.adapter.Adapter`.""" + + required_extras: str = "image" + + def __init__(self, task: AdapterTask, backbone: torch.nn.Module, head: torch.nn.Module): + super().__init__() + + self._task = NoModule(task) + self.backbone = backbone + self.head = head + + @classmethod + @catch_url_error + def from_task( + cls, + *args, + task: AdapterTask, + backbone: torch.nn.Module, + head: torch.nn.Module, + **kwargs, + ) -> Adapter: + return cls(task, backbone, head) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return Task.training_step(self._task.task, batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return Task.validation_step(self._task.task, batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return Task.test_step(self._task.task, batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch[DefaultDataKeys.PREDS] = Task.predict_step( + self._task.task, (batch[DefaultDataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + ) + return batch + + def forward(self, x) -> torch.Tensor: + # TODO: Resolve this hack + if x.dim() == 3: + x = x.unsqueeze(0) + x = self.backbone(x) + if x.dim() == 4: + x = x.mean(-1).mean(-1) + return self.head(x) + + +TRAINING_STRATEGIES = FlashRegistry("training_strategies") +TRAINING_STRATEGIES(name="default", fn=partial(DefaultAdapter.from_task)) + +if _LEARN2LEARN_AVAILABLE: + from learn2learn import algorithms + + for algorithm in dir(algorithms): + # skip base class + if algorithm == "LightningEpisodicModule": + continue + try: + if "lightning" in algorithm.lower() and issubclass(getattr(algorithms, algorithm), LightningModule): + TRAINING_STRATEGIES( + name=algorithm.lower().replace("lightning", ""), + fn=partial(Learn2LearnAdapter.from_task, algorithm=getattr(algorithms, algorithm)), + providers=[_LEARN2LEARN], + ) + except Exception: + pass diff --git a/flash/image/classification/cli.py b/flash/image/classification/cli.py index 6804c909f8..4056387b86 100644 --- a/flash/image/classification/cli.py +++ b/flash/image/classification/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -22,7 +21,7 @@ def from_hymenoptera( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> ImageClassificationData: """Downloads and loads the Hymenoptera (Ants, Bees) data set.""" @@ -38,7 +37,7 @@ def from_hymenoptera( def from_movie_posters( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> ImageClassificationData: """Downloads and loads the movie posters genre classification data set.""" diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index a482cc0e53..af389fc9ba 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -130,7 +130,7 @@ def from_data_frame( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": @@ -227,7 +227,7 @@ def from_csv( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": diff --git a/flash/image/classification/integrations/__init__.py b/flash/image/classification/integrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/image/classification/integrations/baal/__init__.py b/flash/image/classification/integrations/baal/__init__.py new file mode 100644 index 0000000000..309a9010c7 --- /dev/null +++ b/flash/image/classification/integrations/baal/__init__.py @@ -0,0 +1,2 @@ +from flash.image.classification.integrations.baal.data import ActiveLearningDataModule # noqa F401 +from flash.image.classification.integrations.baal.loop import ActiveLearningLoop # noqa F401 diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py new file mode 100644 index 0000000000..b4a439e9a7 --- /dev/null +++ b/flash/image/classification/integrations/baal/data.py @@ -0,0 +1,168 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, List, Optional + +import numpy as np +import torch +from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data import DataLoader, Dataset, random_split + +from flash import DataModule +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_pipeline import DataPipeline +from flash.core.utilities.imports import _BAAL_AVAILABLE, requires + +if _BAAL_AVAILABLE: + from baal.active.dataset import ActiveLearningDataset + from baal.active.heuristics import AbstractHeuristic, BALD +else: + + class AbstractHeuristic: + pass + + class BALD(AbstractHeuristic): + pass + + +def dataset_to_non_labelled_tensor(dataset: BaseAutoDataset) -> torch.tensor: + return torch.zeros(len(dataset)) + + +def filter_unlabelled_data(dataset: BaseAutoDataset) -> Dataset: + return dataset + + +def train_val_split(dataset: Dataset, val_size: float = 0.1): + L = len(dataset) + train_size = int(L * (1 - val_size)) + val_size = L - train_size + return random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)) + + +class ActiveLearningDataModule(LightningDataModule): + @requires("baal") + def __init__( + self, + labelled: Optional[DataModule] = None, + heuristic: "AbstractHeuristic" = BALD(), + map_dataset_to_labelled: Optional[Callable] = dataset_to_non_labelled_tensor, + filter_unlabelled_data: Optional[Callable] = filter_unlabelled_data, + num_label_randomly: int = 5, + val_split: Optional[float] = None, + ): + """The `ActiveLearningDataModule` handles data manipulation for ActiveLearning. + + Args: + labelled: DataModule containing labelled train data for research use-case. + The labelled data would be masked. + heuristic: Sorting algorithm used to rank samples on how likely they can help with model performance. + map_dataset_to_labelled: Function used to emulate masking on labelled dataset. + filter_unlabelled_data: Function used to filter the unlabelled data while computing uncertainties. + num_label_randomly: Number of samples to randomly label from the uncertainty scores. + val_split: Float to split train dataset into train and validation set. + """ + super().__init__() + self.labelled = labelled + self.heuristic = heuristic + self.map_dataset_to_labelled = map_dataset_to_labelled + self.filter_unlabelled_data = filter_unlabelled_data + self.num_label_randomly = num_label_randomly + self.val_split = val_split + self._dataset: Optional[ActiveLearningDataset] = None + + if not self.labelled: + raise MisconfigurationException("The labelled `datamodule` should be provided.") + + if not self.labelled.num_classes: + raise MisconfigurationException("The labelled dataset should be labelled") + + if self.labelled and (self.labelled._val_ds is not None or self.labelled._predict_ds is not None): + raise MisconfigurationException("The labelled `datamodule` should have only train data.") + + self._dataset = ActiveLearningDataset( + self.labelled._train_ds, labelled=self.map_dataset_to_labelled(self.labelled._train_ds) + ) + + if not self.val_split or not self.has_labelled_data: + self.val_dataloader = None + elif self.val_split < 0 or self.val_split > 1: + raise MisconfigurationException("The `val_split` should a float between 0 and 1.") + + if self.labelled._test_ds: + self.test_dataloader = self._test_dataloader + + @property + def has_test(self) -> bool: + return self.labelled._test_ds is not None + + @property + def has_labelled_data(self) -> bool: + return self._dataset.n_labelled > 0 + + @property + def has_unlabelled_data(self) -> bool: + return self._dataset.n_unlabelled > 0 + + @property + def num_classes(self) -> Optional[int]: + return getattr(self.labelled, "num_classes", None) or getattr(self.unlabelled, "num_classes", None) + + @property + def data_pipeline(self) -> "DataPipeline": + return self.labelled.data_pipeline + + def train_dataloader(self) -> "DataLoader": + if self.val_split: + self.labelled._train_ds = train_val_split(self._dataset, self.val_split)[0] + else: + self.labelled._train_ds = self._dataset + + if self.has_labelled_data and self.val_split: + self.val_dataloader = self._val_dataloader + + return self.labelled.train_dataloader() + + def _val_dataloader(self) -> "DataLoader": + self.labelled._val_ds = train_val_split(self._dataset, self.val_split)[1] + return self.labelled._val_dataloader() + + def _test_dataloader(self) -> "DataLoader": + return self.labelled.test_dataloader() + + def predict_dataloader(self) -> "DataLoader": + self.labelled._train_ds = self.filter_unlabelled_data(self._dataset.pool) + return self.labelled.train_dataloader() + + def label(self, probabilities: List[torch.Tensor] = None, indices=None): + if probabilities is not None and indices: + raise MisconfigurationException( + "The `probabilities` and `indices` are mutually exclusive, pass only of one them." + ) + if probabilities is not None: + uncertainties = self.heuristic.get_uncertainties(torch.cat(probabilities, dim=0)) + indices = np.argsort(uncertainties) + if self._dataset is not None: + unlabelled_mask = self._dataset.labelled == False # noqa E712 + unlabelled = self._dataset.labelled[unlabelled_mask] + unlabelled[indices[-self.num_label_randomly :]] = True + self._dataset.labelled[unlabelled_mask] = unlabelled + else: + self._dataset.label_randomly(self.num_label_randomly) + + def state_dict(self) -> Dict[str, torch.Tensor]: + return self._dataset.state_dict() + + def load_state_dict(self, state_dict) -> None: + return self._dataset.load_state_dict(state_dict) diff --git a/flash/image/classification/integrations/baal/dropout.py b/flash/image/classification/integrations/baal/dropout.py new file mode 100644 index 0000000000..02ed7361ed --- /dev/null +++ b/flash/image/classification/integrations/baal/dropout.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +import flash +from flash.core.utilities.imports import _BAAL_AVAILABLE + +if _BAAL_AVAILABLE: + from baal.bayesian.dropout import _patch_dropout_layers + + +class InferenceMCDropoutTask(flash.Task): + def __init__(self, module: flash.Task, inference_iteration: int): + super().__init__() + self.parent_module = module + self.trainer = module.trainer + changed = _patch_dropout_layers(self.parent_module) + if not changed: + raise MisconfigurationException("The model should contain at least 1 dropout layer.") + self.inference_iteration = inference_iteration + + def predict_step(self, batch, batch_idx, dataloader_idx: int = 0): + out = [] + for _ in range(self.inference_iteration): + out.append(self.parent_module.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)) + + # BaaL expects a shape [num_samples, num_classes, num_iterations] + return torch.tensor(out).permute((1, 2, 0)) diff --git a/flash/image/classification/integrations/baal/loop.py b/flash/image/classification/integrations/baal/loop.py new file mode 100644 index 0000000000..f4b2d7f6fd --- /dev/null +++ b/flash/image/classification/integrations/baal/loop.py @@ -0,0 +1,150 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import deepcopy +from typing import Any, Dict, Optional + +import torch +from pytorch_lightning.loops import Loop +from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader +from pytorch_lightning.trainer.progress import Progress +from pytorch_lightning.trainer.states import RunningStage, TrainerFn + +import flash +from flash.core.data.utils import _STAGES_PREFIX +from flash.core.utilities.imports import requires +from flash.image.classification.integrations.baal.data import ActiveLearningDataModule +from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask + + +class ActiveLearningLoop(Loop): + @requires("baal") + def __init__(self, label_epoch_frequency: int, inference_iteration: int = 2, should_reset_weights: bool = True): + """The `ActiveLearning Loop` describes the following training procedure. This loop is connected with the + `ActiveLearningTrainer` + + Example:: + + while unlabelled data or budget critera not reached: + + if labelled data + trainer.fit(model, labelled data) + + if unlabelled data: + predictions = trainer.predict(model, unlabelled data) + uncertainties = heuristic(predictions) + request labellelisation for the sample with highest uncertainties under a given budget + + Args: + label_epoch_frequency: Number of epoch to train on before requesting labellisation. + inference_iteration: Number of inference to perform to compute uncertainty. + """ + super().__init__() + self.label_epoch_frequency = label_epoch_frequency + self.inference_iteration = inference_iteration + self.should_reset_weights = should_reset_weights + self.fit_loop: Optional[FitLoop] = None + self.progress = Progress() + self._model_state_dict: Optional[Dict[str, torch.Tensor]] = None + self._lightning_module: Optional[flash.Task] = None + + @property + def done(self) -> bool: + return self.progress.current.completed >= self.max_epochs + + def connect(self, fit_loop: FitLoop): + self.fit_loop = fit_loop + self.max_epochs = self.fit_loop.max_epochs + self.fit_loop.max_epochs = self.label_epoch_frequency + + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + assert isinstance(self.trainer.datamodule, ActiveLearningDataModule) + self.trainer.predict_loop._return_predictions = True + self._lightning_module = self.trainer.lightning_module + self._model_state_dict = deepcopy(self._lightning_module.state_dict()) + self.inference_model = InferenceMCDropoutTask(self._lightning_module, self.inference_iteration) + + def reset(self) -> None: + pass + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + if self.trainer.datamodule.has_labelled_data: + self._reset_dataloader_for_stage(RunningStage.TRAINING) + self._reset_dataloader_for_stage(RunningStage.VALIDATING) + if self.trainer.datamodule.has_unlabelled_data: + self._reset_dataloader_for_stage(RunningStage.PREDICTING) + self.progress.increment_ready() + + def advance(self, *args: Any, **kwargs: Any) -> None: + self.progress.increment_started() + + if self.trainer.datamodule.has_labelled_data: + self.fit_loop.run() + + if self.trainer.datamodule.has_test: + self.trainer.test_loop.run() + + if self.trainer.datamodule.has_unlabelled_data: + self._reset_predicting() + probabilities = self.trainer.predict_loop.run() + self.trainer.datamodule.label(probabilities=probabilities) + else: + raise StopIteration + + self._reset_fitting() + self.progress.increment_processed() + + def on_advance_end(self) -> None: + if self.trainer.datamodule.has_unlabelled_data and self.should_reset_weights: + # reload the weights to retrain from scratch with the new labelled data. + self._lightning_module.load_state_dict(self._model_state_dict) + self.progress.increment_completed() + return super().on_advance_end() + + def on_run_end(self): + self._reset_fitting() + return super().on_run_end() + + def on_save_checkpoint(self) -> Dict: + return {"datamodule_state_dict": self.trainer.datamodule.state_dict()} + + def on_load_checkpoint(self, state_dict) -> None: + self.trainer.datamodule.load_state_dict(state_dict.pop("datamodule_state_dict")) + + def __getattr__(self, key): + if key not in self.__dict__: + return getattr(self.fit_loop, key) + return self.__dict__[key] + + def _reset_fitting(self): + self.trainer.state.fn = TrainerFn.FITTING + self.trainer.training = True + self.trainer.lightning_module.on_train_dataloader() + self.trainer.accelerator.connect(self._lightning_module) + + def _reset_predicting(self): + self.trainer.state.fn = TrainerFn.PREDICTING + self.trainer.predicting = True + self.trainer.lightning_module.on_predict_dataloader() + self.trainer.accelerator.connect(self.inference_model) + + def _reset_dataloader_for_stage(self, running_state: RunningStage): + dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader" + setattr( + self.trainer.lightning_module, + dataloader_name, + _PatchDataLoader(getattr(self.trainer.datamodule, dataloader_name)(), running_state), + ) + setattr(self.trainer, dataloader_name, None) + getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module) diff --git a/flash/image/classification/integrations/learn2learn.py b/flash/image/classification/integrations/learn2learn.py new file mode 100644 index 0000000000..255da82506 --- /dev/null +++ b/flash/image/classification/integrations/learn2learn.py @@ -0,0 +1,147 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Note: This file will be deleted once +https://github.com/learnables/learn2learn/pull/257/files is merged within Learn2Learn. +""" + +from typing import Any, Callable, Optional + +import pytorch_lightning as pl +from torch.utils.data import IterableDataset +from torch.utils.data._utils.collate import default_collate +from torch.utils.data._utils.worker import get_worker_info + +from flash.core.utilities.imports import requires + + +class TaskDataParallel(IterableDataset): + @requires("learn2learn") + def __init__( + self, + tasks: Any, + epoch_length: int, + devices: int = 1, + collate_fn: Optional[Callable] = default_collate, + ): + """This class is used to sample epoch_length tasks to represent an epoch. + + It should be used when using DataParallel + + Args: + taskset: Dataset used to sample task. + epoch_length: The expected epoch length. This requires to be divisible by devices. + devices: Number of devices being used. + collate_fn: The collate_fn to be applied on multiple tasks + """ + self.tasks = tasks + self.epoch_length = epoch_length + self.devices = devices + + if epoch_length % devices != 0: + raise Exception("The `epoch_length` should be the number of `devices`.") + + self.collate_fn = collate_fn + self.counter = 0 + + def __iter__(self): + self.counter = 0 + return self + + def __next__(self): + if self.counter >= len(self): + raise StopIteration + self.counter += self.devices + tasks = [] + for _ in range(self.devices): + for item in self.tasks.sample(): + tasks.append(item) + if self.collate_fn: + tasks = self.collate_fn(tasks) + return tasks + + def __len__(self): + return self.epoch_length + + +class TaskDistributedDataParallel(IterableDataset): + @requires("learn2learn") + def __init__( + self, + taskset: Any, + global_rank: int, + world_size: int, + num_workers: int, + epoch_length: int, + seed: int, + requires_divisible: bool = True, + ): + """This class is used to sample tasks in a distributed setting such as DDP with multiple workers. + + This won't work as expected if `num_workers = 0` and several dataloaders + are being iterated on at the same time. + + Args: + taskset: Dataset used to sample task. + global_rank: Rank of the current process. + world_size: Total of number of processes. + num_workers: Number of workers to be provided to the DataLoader. + epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). + seed: The seed will be used on __iter__ call and should be the same for all processes. + """ + self.taskset = taskset + self.global_rank = global_rank + self.world_size = world_size + self.num_workers = 1 if num_workers == 0 else num_workers + self.worker_world_size = self.world_size * self.num_workers + self.epoch_length = epoch_length + self.seed = seed + self.iteration = 0 + self.iteration = 0 + self.requires_divisible = requires_divisible + self.counter = 0 + + if requires_divisible and epoch_length % self.worker_world_size != 0: + raise Exception("The `epoch_length` should be divisible by `world_size`.") + + def __len__(self) -> int: + return self.epoch_length // self.world_size + + @property + def worker_id(self) -> int: + worker_info = get_worker_info() + return worker_info.id if worker_info else 0 + + @property + def worker_rank(self) -> int: + is_global_zero = self.global_rank == 0 + return self.global_rank + self.worker_id + int(not is_global_zero and self.num_workers > 1) + + def __iter__(self): + self.iteration += 1 + self.counter = 0 + pl.seed_everything(self.seed + self.iteration) + return self + + def __next__(self): + if self.counter >= len(self): + raise StopIteration + task_descriptions = [] + for _ in range(self.worker_world_size): + task_descriptions.append(self.taskset.sample_task_description()) + + data = self.taskset.get_task(task_descriptions[self.worker_rank]) + self.counter += 1 + return data diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 89071ad71c..a81be9c45a 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -15,18 +15,19 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler from torchmetrics import Metric -from flash.core.classification import ClassificationTask, Labels -from flash.core.data.data_source import DefaultDataKeys +from flash.core.classification import ClassificationAdapterTask, Labels from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry +from flash.image.classification.adapters import TRAINING_STRATEGIES from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES -class ImageClassifier(ClassificationTask): +class ImageClassifier(ClassificationAdapterTask): """The ``ImageClassifier`` is a :class:`~flash.Task` for classifying images. For more details, see :ref:`image_classification`. The ``ImageClassifier`` also supports multi-label classification with ``multi_label=True``. For more details, see :ref:`image_classification_multi_label`. @@ -68,12 +69,13 @@ def fn_resnet(pretrained: bool = True): """ backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES + training_strategies: FlashRegistry = TRAINING_STRATEGIES required_extras: str = "image" def __init__( self, - num_classes: int, + num_classes: Optional[int] = None, backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", backbone_kwargs: Optional[Dict] = None, head: Optional[Union[FunctionType, nn.Module]] = None, @@ -87,59 +89,61 @@ def __init__( learning_rate: float = 1e-3, multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + training_strategy: Optional[str] = "default", + training_strategy_kwargs: Optional[Dict[str, Any]] = None, ): - super().__init__( - num_classes=num_classes, - model=None, - loss_fn=loss_fn, - optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, - metrics=metrics, - learning_rate=learning_rate, - multi_label=multi_label, - serializer=serializer or Labels(multi_label=multi_label), - ) self.save_hyperparameters() if not backbone_kwargs: backbone_kwargs = {} + if not training_strategy_kwargs: + training_strategy_kwargs = {} + + if training_strategy == "default": + if not num_classes: + raise MisconfigurationException("`num_classes` should be provided.") + else: + num_classes = training_strategy_kwargs.get("ways", None) + if not num_classes: + raise MisconfigurationException( + "`training_strategy_kwargs` should contain `ways`, `meta_batch_size` and `shots`." + ) + if isinstance(backbone, tuple): - self.backbone, num_features = backbone + backbone, num_features = backbone else: - self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) + backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) head = head(num_features, num_classes) if isinstance(head, FunctionType) else head - self.head = head or nn.Sequential( + head = head or nn.Sequential( nn.Linear(num_features, num_classes), ) - def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().validation_step(batch, batch_idx) - - def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().test_step(batch, batch_idx) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch[DefaultDataKeys.PREDS] = super().predict_step( - (batch[DefaultDataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + adapter_from_class = self.training_strategies.get(training_strategy) + adapter = adapter_from_class( + task=self, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + **training_strategy_kwargs, ) - return batch - def forward(self, x) -> torch.Tensor: - x = self.backbone(x) - if x.dim() == 4: - x = x.mean(-1).mean(-1) - return self.head(x) + super().__init__( + adapter, + num_classes=num_classes, + loss_fn=loss_fn, + metrics=metrics, + learning_rate=learning_rate, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + multi_label=multi_label, + serializer=serializer or Labels(multi_label=multi_label), + ) @classmethod def available_pretrained_weights(cls, backbone: str): diff --git a/flash/image/classification/transforms.py b/flash/image/classification/transforms.py index 3b5ba98a4c..738823a56e 100644 --- a/flash/image/classification/transforms.py +++ b/flash/image/classification/transforms.py @@ -19,7 +19,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, requires if _KORNIA_AVAILABLE: import kornia as K @@ -28,6 +28,21 @@ import torchvision from torchvision import transforms as T +if _ALBUMENTATIONS_AVAILABLE: + import albumentations + + +class AlbumentationsAdapter(torch.nn.Module): + @requires("albumentations") + def __init__(self, transform): + super().__init__() + if not isinstance(transform, list): + transform = [transform] + self.transform = albumentations.Compose(transform) + + def forward(self, x): + return torch.from_numpy(self.transform(image=x.numpy())["image"]) + def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms for image classification: resize the image, convert the image and target to a tensor, diff --git a/flash/image/data.py b/flash/image/data.py index 35d37281a5..5d0eb9cbe5 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import base64 +from collections import defaultdict from io import BytesIO from pathlib import Path from typing import Any, Dict, Optional @@ -71,6 +72,16 @@ def example_input(self) -> str: return base64.b64encode(f.read()).decode("UTF-8") +def _labels_to_indices(data): + out = defaultdict(list) + for idx, sample in enumerate(data): + label = sample[DefaultDataKeys.TARGET] + if torch.is_tensor(label): + label = label.item() + out[label].append(idx) + return out + + class ImagePathsDataSource(PathsDataSource): def __init__(self): super().__init__(loader=image_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) diff --git a/flash/image/detection/cli.py b/flash/image/detection/cli.py index 8c2eb0c3d1..f955e34bbe 100644 --- a/flash/image/detection/cli.py +++ b/flash/image/detection/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -23,7 +22,7 @@ def from_coco_128( val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> ObjectDetectionData: """Downloads and loads the COCO 128 data set.""" diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 9b00375d99..9a7e5c31fa 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -200,7 +200,7 @@ def from_coco( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders @@ -279,7 +279,7 @@ def from_voc( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders @@ -358,7 +358,7 @@ def from_via( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders diff --git a/flash/image/embedding/backbones/__init__.py b/flash/image/embedding/backbones/__init__.py new file mode 100644 index 0000000000..7781040e63 --- /dev/null +++ b/flash/image/embedding/backbones/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.backbones.vissl_backbones import register_vissl_backbones # noqa: F401 + +IMAGE_EMBEDDER_BACKBONES = FlashRegistry("embedder_backbones") +register_vissl_backbones(IMAGE_EMBEDDER_BACKBONES) diff --git a/flash/image/embedding/backbones/vissl_backbones.py b/flash/image/embedding/backbones/vissl_backbones.py new file mode 100644 index 0000000000..18bb214efe --- /dev/null +++ b/flash/image/embedding/backbones/vissl_backbones.py @@ -0,0 +1,117 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from vissl.config.attr_dict import AttrDict + from vissl.models.model_helpers import RESNET_NORM_LAYER + from vissl.models.trunks import MODEL_TRUNKS_REGISTRY + + from flash.image.embedding.vissl.adapter import VISSLAdapter +else: + RESNET_NORM_LAYER = object + + +def vision_transformer( + image_size: int = 224, + patch_size: int = 16, + hidden_dim: int = 384, + num_layers: int = 12, + num_heads: int = 6, + mlp_dim: int = 1532, + dropout_rate: float = 0, + attention_dropout_rate: float = 0, + drop_path_rate: float = 0, + qkv_bias: bool = True, + qk_scale: bool = False, + classifier: str = "token", + **kwargs, +) -> nn.Module: + + cfg = VISSLAdapter.get_model_config_template() + cfg.TRUNK = AttrDict( + { + "NAME": "vision_transformer", + "VISION_TRANSFORMERS": AttrDict( + { + "IMAGE_SIZE": image_size, + "PATCH_SIZE": patch_size, + "HIDDEN_DIM": hidden_dim, + "NUM_LAYERS": num_layers, + "NUM_HEADS": num_heads, + "MLP_DIM": mlp_dim, + "DROPOUT_RATE": dropout_rate, + "ATTENTION_DROPOUT_RATE": attention_dropout_rate, + "DROP_PATH_RATE": drop_path_rate, + "QKV_BIAS": qkv_bias, + "QK_SCALE": qk_scale, + "CLASSIFIER": classifier, + } + ), + } + ) + + trunk = MODEL_TRUNKS_REGISTRY["vision_transformer"](cfg, model_name="vision_transformer") + trunk.model_config = cfg + + return trunk, trunk.num_features + + +def resnet( + depth: int = 50, + width_multiplier: int = 1, + norm: RESNET_NORM_LAYER = None, + groupnorm_groups: int = 32, + standardize_convolutions: bool = False, + groups: int = 1, + zero_init_residual: bool = False, + width_per_group: int = 64, + layer4_stride: int = 2, + **kwargs, +) -> nn.Module: + if norm is None: + norm = RESNET_NORM_LAYER.BatchNorm + cfg = VISSLAdapter.get_model_config_template() + cfg.TRUNK = AttrDict( + { + "NAME": "resnet", + "RESNETS": AttrDict( + { + "DEPTH": depth, + "WIDTH_MULTIPLIER": width_multiplier, + "NORM": norm, + "GROUPNORM_GROUPS": groupnorm_groups, + "STANDARDIZE_CONVOLUTIONS": standardize_convolutions, + "GROUPS": groups, + "ZERO_INIT_RESIDUAL": zero_init_residual, + "WIDTH_PER_GROUP": width_per_group, + "LAYER4_STRIDE": layer4_stride, + } + ), + } + ) + + trunk = MODEL_TRUNKS_REGISTRY["resnet"](cfg, model_name="resnet") + trunk.model_config = cfg + + return trunk, 2048 + + +def register_vissl_backbones(register: FlashRegistry): + if _VISSL_AVAILABLE: + for backbone in (vision_transformer, resnet): + register(backbone) diff --git a/flash/image/embedding/heads/__init__.py b/flash/image/embedding/heads/__init__.py new file mode 100644 index 0000000000..0afd7bc39d --- /dev/null +++ b/flash/image/embedding/heads/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.heads.vissl_heads import register_vissl_heads # noqa: F401 + +IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads") +register_vissl_heads(IMAGE_EMBEDDER_HEADS) diff --git a/flash/image/embedding/heads/vissl_heads.py b/flash/image/embedding/heads/vissl_heads.py new file mode 100644 index 0000000000..5fb7817f67 --- /dev/null +++ b/flash/image/embedding/heads/vissl_heads.py @@ -0,0 +1,165 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +import torch +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from vissl.config.attr_dict import AttrDict + from vissl.models.heads import MODEL_HEADS_REGISTRY, register_model_head + + from flash.image.embedding.vissl.adapter import VISSLAdapter +else: + AttrDict = object + + +class SimCLRHead(nn.Module): + """VISSL adpots a complicated config input to create an MLP. + + This class simplifies the standard SimCLR projection head. + Can be configured to be used with barlow twins and moco as well. + + Returns MLP according to dimensions provided as a list. + linear-layer -> batch-norm (if flag) -> Relu -> ... + + Args: + model_config: Model config AttrDict from VISSL + dims: list of dimensions for creating a projection head + use_bn: use batch-norm after each linear layer or not + """ + + def __init__( + self, + model_config: AttrDict, + dims: List[int] = [2048, 2048, 128], + use_bn: bool = True, + **kwargs, + ) -> nn.Module: + super().__init__() + + self.model_config = model_config + self.dims = dims + self.use_bn = use_bn + + self.clf = self.create_mlp() + + def create_mlp(self): + layers = [] + last_dim = self.dims[0] + + for dim in self.dims[1:-1]: + layers.append(nn.Linear(last_dim, dim)) + + if self.use_bn: + layers.append( + nn.BatchNorm1d( + dim, + eps=self.model_config.HEAD.BATCHNORM_EPS, + momentum=self.model_config.HEAD.BATCHNORM_MOMENTUM, + ) + ) + + layers.append(nn.ReLU(inplace=True)) + last_dim = dim + + layers.append(nn.Linear(last_dim, self.dims[-1])) + return nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.clf(x) + + +if _VISSL_AVAILABLE: + SimCLRHead = register_model_head("simclr_head")(SimCLRHead) + + +def simclr_head( + dims: List[int] = [2048, 2048, 128], + use_bn: bool = True, + **kwargs, +) -> nn.Module: + cfg = VISSLAdapter.get_model_config_template() + head_kwargs = { + "dims": dims, + "use_bn": use_bn, + } + + cfg.HEAD.PARAMS.append(["simclr_head", head_kwargs]) + + head = MODEL_HEADS_REGISTRY["simclr_head"](cfg, **head_kwargs) + head.model_config = cfg + + return head + + +def swav_head( + dims: List[int] = [2048, 2048, 128], + use_bn: bool = True, + num_clusters: Union[int, List[int]] = [3000], + use_bias: bool = True, + return_embeddings: bool = True, + skip_last_bn: bool = True, + normalize_feats: bool = True, + activation_name: str = "ReLU", + use_weight_norm_prototypes: bool = False, + **kwargs, +) -> nn.Module: + cfg = VISSLAdapter.get_model_config_template() + head_kwargs = { + "dims": dims, + "use_bn": use_bn, + "num_clusters": [num_clusters] if isinstance(num_clusters, int) else num_clusters, + "use_bias": use_bias, + "return_embeddings": return_embeddings, + "skip_last_bn": skip_last_bn, + "normalize_feats": normalize_feats, + "activation_name": activation_name, + "use_weight_norm_prototypes": use_weight_norm_prototypes, + } + + cfg.HEAD.PARAMS.append(["swav_head", head_kwargs]) + + head = MODEL_HEADS_REGISTRY["swav_head"](cfg, **head_kwargs) + head.model_config = cfg + + return head + + +def barlow_twins_head(**kwargs) -> nn.Module: + return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs) + + +def moco_head(**kwargs) -> nn.Module: + return simclr_head(**kwargs) + + +def dino_head(**kwargs) -> nn.Module: + return swav_head( + dims=[384, 2048, 2048, 256], + use_bn=False, + return_embeddings=False, + activation_name="GELU", + num_clusters=[65536], + use_weight_norm_prototypes=True, + **kwargs, + ) + + +def register_vissl_heads(register: FlashRegistry): + for ssl_head in (swav_head, simclr_head, moco_head, dino_head, barlow_twins_head): + register(ssl_head) diff --git a/flash/image/embedding/losses/__init__.py b/flash/image/embedding/losses/__init__.py new file mode 100644 index 0000000000..71c0717e21 --- /dev/null +++ b/flash/image/embedding/losses/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.losses.vissl_losses import register_vissl_losses # noqa: F401 + +IMAGE_EMBEDDER_LOSS_FUNCTIONS = FlashRegistry("embedder_losses") +register_vissl_losses(IMAGE_EMBEDDER_LOSS_FUNCTIONS) diff --git a/flash/image/embedding/losses/vissl_losses.py b/flash/image/embedding/losses/vissl_losses.py new file mode 100644 index 0000000000..87dcf5260c --- /dev/null +++ b/flash/image/embedding/losses/vissl_losses.py @@ -0,0 +1,171 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + import vissl.losses # noqa: F401 + from classy_vision.losses import ClassyLoss, LOSS_REGISTRY + from vissl.config.attr_dict import AttrDict +else: + AttrDict = object + ClassyLoss = object + + +def get_loss_fn(loss_name: str, cfg: AttrDict): + loss_fn = LOSS_REGISTRY[loss_name](cfg) + loss_fn.__dict__["loss_name"] = loss_name + + return loss_fn + + +def dino_loss( + num_crops: int = 10, + momentum: float = 0.996, + student_temp: float = 0.1, + teacher_temp_min: float = 0.04, + teacher_temp_max: float = 0.07, + teacher_temp_warmup_iters: int = 37530, # convert this to 30 epochs + crops_for_teacher: List[int] = [0, 1], + ema_center: float = 0.9, + normalize_last_layer: bool = False, + output_dim: int = 65536, + **kwargs, +) -> ClassyLoss: + loss_name = "dino_loss" + cfg = AttrDict( + { + "num_crops": num_crops, + "momentum": momentum, + "student_temp": student_temp, + "teacher_temp_min": teacher_temp_min, + "teacher_temp_max": teacher_temp_max, + "teacher_temp_warmup_iters": teacher_temp_warmup_iters, + "crops_for_teacher": crops_for_teacher, + "ema_center": ema_center, + "normalize_last_layer": normalize_last_layer, + "output_dim": output_dim, + } + ) + + return get_loss_fn(loss_name, cfg) + + +def swav_loss( + embedding_dim: int = 128, + temperature: float = 0.1, + use_double_precision: bool = False, + normalize_last_layer: bool = True, + num_iters: int = 3, + epsilon: float = 0.05, + num_crops: int = 8, + crops_for_assign: List[int] = [0, 1], + num_prototypes: Union[int, List[int]] = 3000, + temp_hard_assignment_iters: int = 0, + output_dir: str = ".", + queue_length: int = 0, + start_iter: int = 0, + local_queue_length: int = 0, +) -> ClassyLoss: + loss_name = "swav_loss" + cfg = AttrDict( + { + "embedding_dim": embedding_dim, + "temperature": temperature, + "use_double_precision": use_double_precision, + "normalize_last_layer": normalize_last_layer, + "num_iters": num_iters, + "epsilon": epsilon, + "num_crops": num_crops, + "crops_for_assign": crops_for_assign, + "num_prototypes": [num_prototypes] if isinstance(num_prototypes, int) else num_prototypes, + "temp_hard_assignment_iters": temp_hard_assignment_iters, + "output_dir": output_dir, + "queue": AttrDict( + { + "queue_length": queue_length, + "start_iter": start_iter, + "local_queue_length": local_queue_length, + } + ), + } + ) + + return get_loss_fn(loss_name, cfg) + + +def barlow_twins_loss( + lambda_: float = 0.0051, scale_loss: float = 0.024, latent_embedding_dim: int = 8192 +) -> ClassyLoss: + loss_name = "barlow_twins_loss" + cfg = AttrDict( + { + "lambda_": lambda_, + "scale_loss": scale_loss, + "embedding_dim": latent_embedding_dim, + } + ) + + return get_loss_fn(loss_name, cfg) + + +def simclr_loss( + temperature: float = 0.1, + embedding_dim: int = 128, + effective_batch_size: int = 1, # set by setup training hook + world_size: int = 1, # set by setup training hook +) -> ClassyLoss: + loss_name = "simclr_info_nce_loss" + cfg = AttrDict( + { + "temperature": temperature, + "buffer_params": AttrDict( + { + "world_size": world_size, + "embedding_dim": embedding_dim, + "effective_batch_size": effective_batch_size, + } + ), + } + ) + + return get_loss_fn(loss_name, cfg) + + +def moco_loss( + embedding_dim: int = 128, + queue_size: int = 65536, + momentum: float = 0.999, + temperature: int = 0.2, + shuffle_batch: bool = True, +) -> ClassyLoss: + loss_name = "moco_loss" + cfg = AttrDict( + { + "embedding_dim": embedding_dim, + "queue_size": queue_size, + "momentum": momentum, + "temperature": temperature, + "shuffle_batch": shuffle_batch, + } + ) + + return get_loss_fn(loss_name, cfg) + + +def register_vissl_losses(register: FlashRegistry): + for loss_fn in (dino_loss, swav_loss, barlow_twins_loss, simclr_loss, moco_loss): + register(loss_fn) diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index c803757ec5..6bcb496a01 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -11,130 +11,139 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union +import warnings +from typing import Any, Dict, List, Optional, Type, Union import torch -from pytorch_lightning.utilities import rank_zero_warn -from torch import nn -from torch.nn import functional as F from torch.optim.lr_scheduler import _LRScheduler -from torchmetrics import Accuracy, Metric +from flash.core.adapter import AdapterTask from flash.core.data.data_source import DefaultDataKeys -from flash.core.model import Task +from flash.core.data.states import CollateFn, PostTensorTransform, PreTensorTransform, ToTensorTransform +from flash.core.data.transforms import ApplyToKeys from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _IMAGE_AVAILABLE -from flash.core.utilities.isinstance import _isinstance -from flash.image.classification.data import ImageClassificationPreprocess +from flash.core.utilities.imports import _VISSL_AVAILABLE -if _IMAGE_AVAILABLE: - from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES +if _VISSL_AVAILABLE: + import classy_vision + import classy_vision.generic.distributed_util + + from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES + from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES + from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS + + # patch this to avoid classy vision/vissl based distributed training + classy_vision.generic.distributed_util.get_world_size = lambda: 1 else: - IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + IMAGE_EMBEDDER_BACKBONES = FlashRegistry("backbones") + IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies") + IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms") -class ImageEmbedder(Task): +class ImageEmbedder(AdapterTask): """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For more details, see :ref:`image_embedder`. Args: - embedding_dim: Dimension of the embedded vector. ``None`` uses the default from the backbone. - backbone: A model to use to extract image features, defaults to ``"swav-imagenet"``. - pretrained: Use a pretrained backbone, defaults to ``True``. - loss_fn: Loss function for training and finetuning, defaults to :func:`torch.nn.functional.cross_entropy` + training_strategy: Training strategy from VISSL, + select between 'simclr', 'swav', 'dino', 'moco', or 'barlow_twins'. + head: projection head used for task, select between + 'simclr_head', 'swav_head', 'dino_head', 'moco_head', or 'barlow_twins_head'. + pretraining_transform: transform applied to input image for pre-training SSL model. + Select between 'simclr_transform', 'swav_transform', 'dino_transform', + 'moco_transform', or 'barlow_twins_transform'. + backbone: VISSL backbone, defaults to ``resnet``. + pretrained: Use a pretrained backbone, defaults to ``False``. optimizer: Optimizer to use for training and finetuning, defaults to :class:`torch.optim.SGD`. optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). scheduler: The scheduler or scheduler class to use. scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). - metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` - package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict - containing a combination of the aforementioned. In all cases, each metric needs to have the signature - `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. - pooling_fn: Function used to pool image to generate embeddings, defaults to :func:`torch.max`. + backbone_kwargs: arguments to be passed to VISSL backbones, i.e. ``vision_transformer`` and ``resnet``. + training_strategy_kwargs: arguments passed to VISSL loss function, projection head and training hooks. + pretraining_transform_kwargs: arguments passed to VISSL transforms. """ - backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES + training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES + backbones: FlashRegistry = IMAGE_EMBEDDER_BACKBONES + transforms: FlashRegistry = IMAGE_EMBEDDER_TRANSFORMS required_extras: str = "image" def __init__( self, - embedding_dim: Optional[int] = None, - backbone: str = "resnet101", - pretrained: bool = True, - loss_fn: Callable = F.cross_entropy, + training_strategy: str, + head: str, + pretraining_transform: str, + backbone: str = "resnet", + pretrained: bool = False, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, - metrics: Union[Metric, Callable, Mapping, Sequence, None] = (Accuracy()), learning_rate: float = 1e-3, - pooling_fn: Callable = torch.max, + backbone_kwargs: Optional[Dict[str, Any]] = None, + training_strategy_kwargs: Optional[Dict[str, Any]] = None, + pretraining_transform_kwargs: Optional[Dict[str, Any]] = None, ): - super().__init__( - model=None, + self.save_hyperparameters() + + if backbone_kwargs is None: + backbone_kwargs = {} + + if training_strategy_kwargs is None: + training_strategy_kwargs = {} + + backbone, _ = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) + + metadata = self.training_strategies.get(training_strategy, with_metadata=True) + loss_fn, head, hooks = metadata["fn"](head=head, **training_strategy_kwargs) + + adapter = metadata["metadata"]["adapter"].from_task( + self, loss_fn=loss_fn, + backbone=backbone, + head=head, + hooks=hooks, + ) + + super().__init__( + adapter=adapter, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, - metrics=metrics, learning_rate=learning_rate, - preprocess=ImageClassificationPreprocess(), ) - self.save_hyperparameters() - self.backbone_name = backbone - self.embedding_dim = embedding_dim - assert pooling_fn in [torch.mean, torch.max] - self.pooling_fn = pooling_fn - - self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained) - - if embedding_dim is None: - self.head = nn.Identity() - else: - self.head = nn.Sequential( - nn.Flatten(), - nn.Linear(num_features, embedding_dim), - ) - rank_zero_warn("Adding linear layer on top of backbone. Remember to finetune first before using!") - - def apply_pool(self, x): - x = self.pooling_fn(x, dim=-1) - if _isinstance(x, Tuple[torch.Tensor, torch.Tensor]): - x = x[0] - x = self.pooling_fn(x, dim=-1) - if _isinstance(x, Tuple[torch.Tensor, torch.Tensor]): - x = x[0] - return x - - def forward(self, x) -> torch.Tensor: - x = self.backbone(x) - - # bolts ssl models return lists - if isinstance(x, tuple): - x = x[-1] - - if x.dim() == 4 and not self.embedding_dim: - x = self.apply_pool(x) - - x = self.head(x) - return x - - def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().validation_step(batch, batch_idx) - - def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return super().test_step(batch, batch_idx) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = batch[DefaultDataKeys.INPUT] - return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + transform, collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) + to_tensor_transform = ApplyToKeys( + DefaultDataKeys.INPUT, + transform, + ) + + self.adapter.set_state(CollateFn(collate_fn)) + self.adapter.set_state(ToTensorTransform(to_tensor_transform)) + self.adapter.set_state(PostTensorTransform(None)) + self.adapter.set_state(PreTensorTransform(None)) + + warnings.warn( + "Warning: VISSL ImageEmbedder overrides any user provided transforms" + " with pre-defined transforms for the training strategy." + ) + + def on_train_start(self) -> None: + self.adapter.on_train_start() + + def on_train_epoch_end(self) -> None: + self.adapter.on_train_epoch_end() + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + + @classmethod + def available_training_strategies(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None) + if registry is None: + return [] + return registry.available_keys() diff --git a/flash/image/embedding/strategies/__init__.py b/flash/image/embedding/strategies/__init__.py new file mode 100644 index 0000000000..8d010d7bb8 --- /dev/null +++ b/flash/image/embedding/strategies/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.strategies.vissl_strategies import register_vissl_strategies # noqa: F401 + +IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies") +register_vissl_strategies(IMAGE_EMBEDDER_STRATEGIES) diff --git a/flash/image/embedding/strategies/vissl_strategies.py b/flash/image/embedding/strategies/vissl_strategies.py new file mode 100644 index 0000000000..2622d7ae5b --- /dev/null +++ b/flash/image/embedding/strategies/vissl_strategies.py @@ -0,0 +1,70 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE +from flash.core.utilities.providers import _VISSL +from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS +from flash.image.embedding.losses import IMAGE_EMBEDDER_LOSS_FUNCTIONS +from flash.image.embedding.vissl.adapter import VISSLAdapter +from flash.image.embedding.vissl.hooks import SimCLRTrainingSetupHook, TrainingSetupHook + +if _VISSL_AVAILABLE: + from vissl.hooks.dino_hooks import DINOHook + from vissl.hooks.moco_hooks import MoCoHook + from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook + + +def dino(head: str = "dino_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("dino_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [DINOHook(), TrainingSetupHook()] + + +def swav(head: str = "swav_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [SwAVUpdateQueueScoresHook(), NormalizePrototypesHook(), TrainingSetupHook()] + + +def simclr(head: str = "simclr_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("simclr_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [SimCLRTrainingSetupHook()] + + +def moco(head: str = "simclr_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("moco_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return ( + loss_fn, + head, + [MoCoHook(loss_fn.loss_config.momentum, loss_fn.loss_config.shuffle_batch), TrainingSetupHook()], + ) + + +def barlow_twins(head: str = "barlow_twins_head", **kwargs): + loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("barlow_twins_loss")(**kwargs) + head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) + + return loss_fn, head, [TrainingSetupHook()] + + +def register_vissl_strategies(register: FlashRegistry): + if _VISSL_AVAILABLE: + for training_strategy in (dino, swav, simclr, moco, barlow_twins): + register(training_strategy, name=training_strategy.__name__, adapter=VISSLAdapter, providers=_VISSL) diff --git a/flash/image/embedding/transforms/__init__.py b/flash/image/embedding/transforms/__init__.py new file mode 100644 index 0000000000..79657f9491 --- /dev/null +++ b/flash/image/embedding/transforms/__init__.py @@ -0,0 +1,5 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.embedding.transforms.vissl_transforms import register_vissl_transforms # noqa: F401 + +IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms") +register_vissl_transforms(IMAGE_EMBEDDER_TRANSFORMS) diff --git a/flash/image/embedding/transforms/vissl_transforms.py b/flash/image/embedding/transforms/vissl_transforms.py new file mode 100644 index 0000000000..8e54354a4f --- /dev/null +++ b/flash/image/embedding/transforms/vissl_transforms.py @@ -0,0 +1,99 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Callable, Optional, Sequence + +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _VISSL_AVAILABLE +from flash.image.embedding.vissl.transforms import moco_collate_fn, multicrop_collate_fn, simclr_collate_fn + +if _VISSL_AVAILABLE: + from classy_vision.dataset.transforms import TRANSFORM_REGISTRY + + +def simclr_transform( + total_num_crops: int = 2, + num_crops: Sequence[int] = [2], + size_crops: Sequence[int] = [224], + crop_scales: Sequence[Sequence[float]] = [[0.4, 1]], + gaussian_blur: bool = True, + jitter_strength: float = 1.0, + normalize: Optional[nn.Module] = None, + collate_fn: Callable = simclr_collate_fn, +) -> nn.Module: + """For simclr, barlow twins and moco.""" + transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( + total_num_crops=total_num_crops, + num_crops=num_crops, + size_crops=size_crops, + crop_scales=crop_scales, + gaussian_blur=gaussian_blur, + jitter_strength=jitter_strength, + normalize=normalize, + ) + + return transform, collate_fn + + +def swav_transform( + total_num_crops: int = 8, + num_crops: Sequence[int] = [2, 6], + size_crops: Sequence[int] = [224, 96], + crop_scales: Sequence[Sequence[float]] = [[0.4, 1], [0.05, 0.4]], + gaussian_blur: bool = True, + jitter_strength: float = 1.0, + normalize: Optional[nn.Module] = None, + collate_fn: Callable = multicrop_collate_fn, +) -> nn.Module: + """For swav and dino.""" + transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( + total_num_crops=total_num_crops, + num_crops=num_crops, + size_crops=size_crops, + crop_scales=crop_scales, + gaussian_blur=gaussian_blur, + jitter_strength=jitter_strength, + normalize=normalize, + ) + + return transform, collate_fn + + +barlow_twins_transform = partial(simclr_transform, collate_fn=simclr_collate_fn) +moco_transform = partial(simclr_transform, collate_fn=moco_collate_fn) +dino_transform = partial(swav_transform, total_num_crops=10, num_crops=[2, 8], collate_fn=multicrop_collate_fn) + + +transforms = [ + "simclr_transform", + "swav_transform", + "barlow_twins_transform", + "moco_transform", + "dino_transform", +] + + +def register_vissl_transforms(register: FlashRegistry): + for idx, transform in enumerate( + ( + simclr_transform, + swav_transform, + barlow_twins_transform, + moco_transform, + dino_transform, + ) + ): + register(transform, name=transforms[idx]) diff --git a/flash/image/embedding/vissl/__init__.py b/flash/image/embedding/vissl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py new file mode 100644 index 0000000000..31a880a572 --- /dev/null +++ b/flash/image/embedding/vissl/adapter.py @@ -0,0 +1,209 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Union + +import torch +import torch.nn as nn + +from flash.core.adapter import Adapter +from flash.core.data.data_source import DefaultDataKeys +from flash.core.model import Task +from flash.core.utilities.imports import _VISSL_AVAILABLE +from flash.image.embedding.vissl.hooks import AdaptVISSLHooks + +if _VISSL_AVAILABLE: + from classy_vision.hooks.classy_hook import ClassyHook + from classy_vision.losses import ClassyLoss + from vissl.config.attr_dict import AttrDict + from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel +else: + ClassyLoss = object + ClassyHook = object + + +class MockVISSLTask: + """Mock task class from VISSL to support loss, configs, base_model, last batch etc.""" + + def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None: + self.vissl_adapter = vissl_adapter + self.loss = vissl_loss + self.config = task_config + self.base_model = vissl_model + self.model = self.base_model # set by property in ClassyTask + + # set using trainingsetuphook + self.device = None + + self.iteration = 0 + self.max_iteration = 1 # set by training setup hook + + # set for momentum teacher based hooks + self.last_batch = AttrDict({"sample": AttrDict({"input": None, "data_momentum": None})}) + + +class VISSLAdapter(Adapter, AdaptVISSLHooks): + """The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL. + + Also inherits from ``AdaptVISSLHooks`` to support VISSL hooks. + """ + + required_extras: str = "image" + + def __init__( + self, + backbone: nn.Module, + head: nn.Module, + loss_fn: ClassyLoss, + hooks: List[ClassyHook], + ) -> None: + + Adapter.__init__(self) + + self.model_config = self.get_model_config_template() + self.optimizer_config = AttrDict({}) + + self.backbone = backbone + self.head = [head] if not isinstance(head, list) else head + self.loss_fn = loss_fn + self.hooks = hooks + + self.model_config.TRUNK = self.backbone.model_config.TRUNK + self.model_config.HEAD = self.head[0].model_config.HEAD + self.task_config = AttrDict( + { + "MODEL": self.model_config, + "OPTIMIZER": self.optimizer_config, + "LOSS": AttrDict( + { + "name": self.loss_fn.loss_name, + self.loss_fn.loss_name: self.loss_fn.loss_config, + } + ), + } + ) + + self.vissl_base_model = BaseSSLMultiInputOutputModel(self.model_config, self.optimizer_config) + # patch backbone and head + self.vissl_base_model.trunk = backbone + self.vissl_base_model.heads = nn.ModuleList(self.head) + + self.vissl_task = MockVISSLTask(self, self.loss_fn, self.task_config, self.vissl_base_model) + + AdaptVISSLHooks.__init__(self, hooks=hooks, task=self.vissl_task) + + @classmethod + def from_task( + cls, + task: Task, + loss_fn: ClassyLoss, + backbone: nn.Module, + head: Union[nn.Module, List[nn.Module]], + hooks: List[ClassyHook], + ) -> Adapter: + result = cls( + backbone=backbone, + head=head, + loss_fn=loss_fn, + hooks=hooks, + ) + + result.__dict__["adapter_task"] = task + + return result + + @staticmethod + def get_model_config_template(): + cfg = AttrDict( + { + "SINGLE_PASS_EVERY_CROP": False, + "INPUT_TYPE": "rgb", + "MULTI_INPUT_HEAD_MAPPING": [], + "TRUNK": AttrDict({}), + "HEAD": AttrDict( + { + "PARAMS": [], + "BATCHNORM_EPS": 1e-5, + "BATCHNORM_MOMENTUM": 0.1, + "PARAMS_MULTIPLIER": 1.0, + } + ), + "FEATURE_EVAL_SETTINGS": AttrDict( + { + "EVAL_MODE_ON": False, + "EXTRACT_TRUNK_FEATURES_ONLY": False, + } + ), + "_MODEL_INIT_SEED": 0, + "ACTIVATION_CHECKPOINTING": AttrDict( + { + "USE_ACTIVATION_CHECKPOINTING": False, + "NUM_ACTIVATION_CHECKPOINTING_SPLITS": 2, + } + ), + } + ) + + return cfg + + def forward(self, batch: torch.Tensor) -> Any: + return self.vissl_base_model.trunk(batch, [])[0] + + def ssl_forward(self, batch) -> Any: + model_output = self.vissl_base_model(batch) + + # vissl-specific + if len(model_output) == 1: + model_output = model_output[0] + + return model_output + + def shared_step(self, batch: Any, train: bool = True) -> Any: + out = self.ssl_forward(batch[DefaultDataKeys.INPUT]) + + # for moco and dino + self.task.last_batch["sample"]["input"] = batch[DefaultDataKeys.INPUT] + if "data_momentum" in batch.keys(): + self.task.last_batch["sample"]["data_momentum"] = [batch["data_momentum"]] + + if train: + # call forward hook from VISSL (momentum updates) + for hook in self.hooks: + hook.on_forward(self.vissl_task) + + loss = self.loss_fn(out, target=None) + + return loss + + def training_step(self, batch: Any, batch_idx: int) -> Any: + loss = self.shared_step(batch) + self.adapter_task.log_dict({"train_loss": loss.item()}) + + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + loss = self.shared_step(batch, train=False) + self.adapter_task.log_dict({"val_loss": loss}) + + return loss + + def test_step(self, batch: Any, batch_idx: int) -> None: + loss = self.shared_step(batch, train=False) + self.adapter_task.log_dict({"test_loss": loss}) + + return loss + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + input_image = batch[DefaultDataKeys.INPUT] + + return self(input_image) diff --git a/flash/image/embedding/vissl/hooks.py b/flash/image/embedding/vissl/hooks.py new file mode 100644 index 0000000000..bd9931d886 --- /dev/null +++ b/flash/image/embedding/vissl/hooks.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List + +import torch +from pytorch_lightning.core.hooks import ModelHooks + +import flash +from flash.core.utilities.imports import _VISSL_AVAILABLE + +if _VISSL_AVAILABLE: + from classy_vision.hooks.classy_hook import ClassyHook +else: + + class ClassyHook: + _noop = object + + +class TrainingSetupHook(ClassyHook): + on_start = ClassyHook._noop + on_phase_start = ClassyHook._noop + on_loss_and_meter = ClassyHook._noop + on_backward = ClassyHook._noop + on_step = ClassyHook._noop + on_phase_end = ClassyHook._noop + on_end = ClassyHook._noop + on_update = ClassyHook._noop + on_forward = ClassyHook._noop + + def __init__(self): + super().__init__() + + @torch.no_grad() + def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") -> None: + lightning_module = task.vissl_adapter.adapter_task + task.device = lightning_module.device + + # get around vissl distributed training by setting MockTask flags + num_nodes = lightning_module.trainer.num_nodes + accelerators_ids = lightning_module.trainer.accelerator_connector.parallel_device_ids + accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1 + task.world_size = num_nodes * accelerator_per_node + + if lightning_module.trainer.max_epochs is None: + lightning_module.trainer.max_epochs = 1 + + task.max_iteration = lightning_module.trainer.max_epochs * lightning_module.trainer.num_training_batches + + +class SimCLRTrainingSetupHook(TrainingSetupHook): + def __init__(self): + super().__init__() + + @torch.no_grad() + def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") -> None: + super().on_start(task) + + lightning_module = task.vissl_adapter.adapter_task + + # specific to simclr in VISSL + task.loss.info_criterion.buffer_params.effective_batch_size = ( + task.world_size * 2 * lightning_module.trainer.datamodule.batch_size + ) + task.loss.info_criterion.buffer_params.world_size = task.world_size + + task.loss.info_criterion.precompute_pos_neg_mask() + + +class AdaptVISSLHooks(ModelHooks): + def __init__(self, hooks: List[ClassyHook], task) -> None: + super().__init__() + + self.hooks = hooks + self.task = task + + def on_train_start(self) -> None: + for hook in self.hooks: + hook.on_start(self.task) + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.task.iteration += 1 + + def on_train_epoch_end(self) -> None: + for hook in self.hooks: + hook.on_update(self.task) diff --git a/flash/core/integrations/vissl/transforms/__init__.py b/flash/image/embedding/vissl/transforms/__init__.py similarity index 50% rename from flash/core/integrations/vissl/transforms/__init__.py rename to flash/image/embedding/vissl/transforms/__init__.py index 804689456e..447aef4fa7 100644 --- a/flash/core/integrations/vissl/transforms/__init__.py +++ b/flash/image/embedding/vissl/transforms/__init__.py @@ -1,9 +1,12 @@ from flash.core.utilities.imports import _VISSL_AVAILABLE # noqa: F401 +from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401 +from flash.image.embedding.vissl.transforms.utilities import ( # noqa: F401 + moco_collate_fn, + multicrop_collate_fn, + simclr_collate_fn, +) if _VISSL_AVAILABLE: from classy_vision.dataset.transforms import register_transform # noqa: F401 - from flash.core.integrations.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401 - from flash.core.integrations.vissl.transforms.utilities import vissl_collate_fn # noqa: F401 - register_transform("multicrop_ssl_transform")(StandardMultiCropSSLTransform) diff --git a/flash/core/integrations/vissl/transforms/multicrop.py b/flash/image/embedding/vissl/transforms/multicrop.py similarity index 97% rename from flash/core/integrations/vissl/transforms/multicrop.py rename to flash/image/embedding/vissl/transforms/multicrop.py index f6dda5c8b5..969cd2a6d7 100644 --- a/flash/core/integrations/vissl/transforms/multicrop.py +++ b/flash/image/embedding/vissl/transforms/multicrop.py @@ -27,6 +27,7 @@ class StandardMultiCropSSLTransform(nn.Module): crops. This transform was proposed in SwAV - https://arxiv.org/abs/2006.09882 + This transform can act as a base transform class for SimCLR, SwAV, MoCo, Barlow Twins and DINO from VISSL. This transform has been modified from the ImgPilToMultiCrop code present at https://github.com/facebookresearch/vissl/blob/master/vissl/data/ssl_transforms/img_pil_to_multicrop.py diff --git a/flash/image/embedding/vissl/transforms/utilities.py b/flash/image/embedding/vissl/transforms/utilities.py new file mode 100644 index 0000000000..7909cbdda2 --- /dev/null +++ b/flash/image/embedding/vissl/transforms/utilities.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from flash.core.data.data_source import DefaultDataKeys + + +def vissl_collate_helper(samples): + result = [] + + for batch_ele in samples: + _batch_ele_dict = {} + _batch_ele_dict.update(batch_ele) + _batch_ele_dict[DefaultDataKeys.INPUT] = -1 + + result.append(_batch_ele_dict) + + return torch.utils.data._utils.collate.default_collate(result) + + +def multicrop_collate_fn(samples): + """Multi-crop collate function for VISSL integration. + + Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT + """ + result = vissl_collate_helper(samples) + + inputs = [[] for _ in range(len(samples[0][DefaultDataKeys.INPUT]))] + for batch_ele in samples: + multi_crop_imgs = batch_ele[DefaultDataKeys.INPUT] + + for idx, crop in enumerate(multi_crop_imgs): + inputs[idx].append(crop) + + for idx, ele in enumerate(inputs): + inputs[idx] = torch.stack(ele) + + result[DefaultDataKeys.INPUT] = inputs + + return result + + +def simclr_collate_fn(samples): + """Multi-crop collate function for VISSL integration. + + Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT + """ + result = vissl_collate_helper(samples) + + inputs = [] + num_views = len(samples[0][DefaultDataKeys.INPUT]) + view_idx = 0 + while view_idx < num_views: + for batch_ele in samples: + imgs = batch_ele[DefaultDataKeys.INPUT] + inputs.append(imgs[view_idx]) + + view_idx += 1 + + result[DefaultDataKeys.INPUT] = torch.stack(inputs) + + return result + + +def moco_collate_fn(samples): + """MOCO collate function for VISSL integration. + + Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT + """ + result = vissl_collate_helper(samples) + + inputs = [] + for batch_ele in samples: + inputs.append(torch.stack(batch_ele[DefaultDataKeys.INPUT])) + + result[DefaultDataKeys.INPUT] = torch.stack(inputs).squeeze()[:, 0, :, :, :].squeeze() + result["data_momentum"] = torch.stack(inputs).squeeze()[:, 1, :, :, :].squeeze() + + return result diff --git a/flash/image/instance_segmentation/cli.py b/flash/image/instance_segmentation/cli.py index 63b2538d09..97960ae5c9 100644 --- a/flash/image/instance_segmentation/cli.py +++ b/flash/image/instance_segmentation/cli.py @@ -28,7 +28,7 @@ def from_pets( val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, parser: Optional[Callable] = None, **preprocess_kwargs, ) -> InstanceSegmentationData: diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 91a1e8eeb1..ab68c3d5ff 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -92,7 +92,7 @@ def from_coco( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the @@ -171,7 +171,7 @@ def from_voc( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the diff --git a/flash/image/keypoint_detection/cli.py b/flash/image/keypoint_detection/cli.py index 5f8d0feade..959328a51c 100644 --- a/flash/image/keypoint_detection/cli.py +++ b/flash/image/keypoint_detection/cli.py @@ -27,7 +27,7 @@ def from_biwi( val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, parser: Optional[Callable] = None, **preprocess_kwargs, ) -> KeypointDetectionData: diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 0e7f700e4a..97948a7d40 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -90,7 +90,7 @@ def from_coco( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data diff --git a/flash/image/segmentation/cli.py b/flash/image/segmentation/cli.py index 64cb0c3d93..2e92877015 100644 --- a/flash/image/segmentation/cli.py +++ b/flash/image/segmentation/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -24,7 +23,7 @@ def from_carla( num_classes: int = 21, val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> SemanticSegmentationData: """Downloads and loads the CARLA capture data set.""" diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 42787df9bc..d9d2a69e57 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -318,7 +318,7 @@ def from_data_source( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ) -> "DataModule": @@ -376,7 +376,7 @@ def from_folders( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, **preprocess_kwargs, diff --git a/flash/image/style_transfer/cli.py b/flash/image/style_transfer/cli.py index 0fec347021..0aab00a4e3 100644 --- a/flash/image/style_transfer/cli.py +++ b/flash/image/style_transfer/cli.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Optional import flash from flash.core.data.utils import download_data @@ -24,7 +23,7 @@ def from_coco_128( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> StyleTransferData: """Downloads and loads the COCO 128 data set.""" diff --git a/flash/pointcloud/detection/cli.py b/flash/pointcloud/detection/cli.py index 01a4c329ce..1acbef5efa 100644 --- a/flash/pointcloud/detection/cli.py +++ b/flash/pointcloud/detection/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -22,7 +21,7 @@ def from_kitti( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> PointCloudObjectDetectorData: """Downloads and loads the KITTI data set.""" diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 411d3a498f..e565c358b7 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -89,7 +89,7 @@ def from_folders( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, scans_folder_name: Optional[str] = "scans", labels_folder_name: Optional[str] = "labels", diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 155126d785..5555bc1d46 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -163,6 +163,7 @@ def _process_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + **kwargs ) -> DataLoader: if not _POINTCLOUD_AVAILABLE: diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py index 0c4872c3b3..ba7b84f670 100644 --- a/flash/pointcloud/detection/open3d_ml/data_sources.py +++ b/flash/pointcloud/detection/open3d_ml/data_sources.py @@ -141,7 +141,7 @@ def clean_fn(path: str) -> str: def predict_load_data(self, data, dataset: Optional[BaseAutoDataset] = None): if (isinstance(data, str) and isfile(data)) or (isinstance(data, list) and all(isfile(p) for p in data)): return self.load_files(data, dataset) - elif isinstance(data, str) and isdir(data): + if isinstance(data, str) and isdir(data): raise NotImplementedError def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None): diff --git a/flash/pointcloud/segmentation/cli.py b/flash/pointcloud/segmentation/cli.py index 57d1125f9b..26a147d68b 100644 --- a/flash/pointcloud/segmentation/cli.py +++ b/flash/pointcloud/segmentation/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -22,7 +21,7 @@ def from_kitti( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> PointCloudSegmentationData: """Downloads and loads the semantic KITTI data set.""" diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index 9342a61758..a8989d9a42 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -192,6 +192,7 @@ def _process_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + **kwargs ) -> DataLoader: if not _POINTCLOUD_AVAILABLE: diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index 63eff2458f..6787b1c8d6 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -22,7 +21,7 @@ def from_titanic( batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> TabularClassificationData: """Downloads and loads the Titanic data set.""" diff --git a/flash/tabular/data.py b/flash/tabular/data.py index da36d726ce..b078344366 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data.sampler import Sampler from flash.core.classification import LabelsState from flash.core.data.callback import BaseDataFetcher @@ -343,7 +344,8 @@ def from_data_frame( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. @@ -372,6 +374,7 @@ def from_data_frame( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -420,6 +423,7 @@ def from_data_frame( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, cat_cols=categorical_fields, num_cols=numerical_fields, target_col=target_fields, @@ -450,7 +454,8 @@ def from_csv( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. @@ -479,6 +484,7 @@ def from_csv( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -506,4 +512,6 @@ def from_csv( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + sampler=sampler, + **preprocess_kwargs, ) diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index f81111bc3c..fc311c0a6a 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -176,7 +176,7 @@ def from_sklearn( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them diff --git a/flash/text/classification/backbones.py b/flash/text/classification/backbones.py new file mode 100644 index 0000000000..0a150feaf7 --- /dev/null +++ b/flash/text/classification/backbones.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# ResNet encoder adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py +# as the official torchvision implementation does not support wide resnet architecture +# found in self-supervised learning model weights +from flash.core.registry import ExternalRegistry, FlashRegistry +from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE +from flash.core.utilities.providers import _HUGGINGFACE + +if _TRANSFORMERS_AVAILABLE: + from transformers import AutoModelForSequenceClassification + +TEXT_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + +if _TRANSFORMERS_AVAILABLE: + HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES = ExternalRegistry( + getter=AutoModelForSequenceClassification.from_pretrained, + name="backbones", + providers=_HUGGINGFACE, + ) + TEXT_CLASSIFIER_BACKBONES += HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES diff --git a/flash/text/classification/cli.py b/flash/text/classification/cli.py index 42499bb53f..0b7be2bd11 100644 --- a/flash/text/classification/cli.py +++ b/flash/text/classification/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -23,7 +22,7 @@ def from_imdb( backbone: str = "prajjwal1/bert-medium", batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> TextClassificationData: """Downloads and loads the IMDB sentiment classification data set.""" @@ -44,7 +43,7 @@ def from_toxic( backbone: str = "unitary/toxic-bert", val_split: float = 0.1, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> TextClassificationData: """Downloads and loads the Jigsaw toxic comments data set.""" diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index a7171af128..085b30988c 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +import pandas as pd import torch from torch import Tensor +from torch.utils.data.sampler import Sampler import flash from flash.core.data.auto_dataset import AutoDataset +from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataSources, LabelsState from flash.core.data.process import Deserializer, Postprocess, Preprocess @@ -26,7 +29,7 @@ from flash.core.utilities.imports import _TEXT_AVAILABLE, requires if _TEXT_AVAILABLE: - from datasets import DatasetDict, load_dataset + from datasets import Dataset, DatasetDict, load_dataset from transformers import AutoTokenizer, default_data_collator from transformers.modeling_outputs import SequenceClassifierOutput @@ -216,6 +219,66 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) +class TextDataFrameDataSource(TextDataSource): + @staticmethod + def _multilabel_target(targets, element): + targets = [element.pop(target) for target in targets] + element["labels"] = targets + return element + + def load_data( + self, + data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]], + dataset: Optional[Any] = None, + columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), + ) -> Union[Sequence[Mapping[str, Any]]]: + df, input, target = data + hf_dataset = Dataset.from_pandas(df) + + if not self.predicting: + if isinstance(target, List): + # multi-target + dataset.multi_label = True + hf_dataset = hf_dataset.map(partial(self._multilabel_target, target)) + dataset.num_classes = len(target) + self.set_state(LabelsState(target)) + else: + dataset.multi_label = False + if self.training: + labels = list(sorted(list(set(hf_dataset[target])))) + dataset.num_classes = len(labels) + self.set_state(LabelsState(labels)) + + labels = self.get_state(LabelsState) + + # convert labels to ids + if labels is not None: + labels = labels.labels + label_to_class_mapping = {v: k for k, v in enumerate(labels)} + hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, target)) + + # Hugging Face models expect target to be named ``labels``. + if target != "labels": + hf_dataset.rename_column_(target, "labels") + + hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=input), batched=True) + hf_dataset.set_format("torch", columns=columns) + + return hf_dataset + + def predict_load_data(self, data: Any, dataset: AutoDataset): + return self.load_data(data, dataset, columns=["input_ids", "attention_mask"]) + + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + + class TextSentencesDataSource(TextDataSource): def __init__(self, backbone: str, max_length: int = 128): super().__init__(backbone, max_length=max_length) @@ -267,6 +330,7 @@ def __init__( data_sources={ DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length), DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length), + "data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length), "sentences": TextSentencesDataSource(self.backbone, max_length=max_length), DefaultDataSources.LABELSTUDIO: LabelStudioTextClassificationDataSource( backbone=self.backbone, max_length=max_length @@ -317,3 +381,76 @@ class TextClassificationData(DataModule): @property def backbone(self) -> Optional[str]: return getattr(self.preprocess, "backbone", None) + + @classmethod + def from_data_frame( + cls, + input_field: str, + target_fields: Union[str, Sequence[str]], + train_data_frame: Optional[pd.DataFrame] = None, + val_data_frame: Optional[pd.DataFrame] = None, + test_data_frame: Optional[pd.DataFrame] = None, + predict_data_frame: Optional[pd.DataFrame] = None, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, + **preprocess_kwargs: Any, + ) -> "DataModule": + """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given pandas + ``DataFrame`` objects. + + Args: + input_field: The field (column) in the pandas ``DataFrame`` to use for the input. + target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target. + train_data_frame: The pandas ``DataFrame`` containing the training data. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + """ + return cls.from_data_source( + "data_frame", + (train_data_frame, input_field, target_fields), + (val_data_frame, input_field, target_fields), + (test_data_frame, input_field, target_fields), + (predict_data_frame, input_field, target_fields), + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, + **preprocess_kwargs, + ) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index d551b74b62..2c4bf4b0d4 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -22,23 +22,14 @@ from flash.core.classification import ClassificationTask, Labels from flash.core.data.process import Serializer -from flash.core.registry import ExternalRegistry, FlashRegistry -from flash.core.utilities.imports import _TEXT_AVAILABLE -from flash.core.utilities.providers import _HUGGINGFACE +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE +from flash.text.classification.backbones import TEXT_CLASSIFIER_BACKBONES from flash.text.ort_callback import ORTCallback -if _TEXT_AVAILABLE: - from transformers import AutoModelForSequenceClassification +if _TRANSFORMERS_AVAILABLE: from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput - HUGGINGFACE_BACKBONES = ExternalRegistry( - AutoModelForSequenceClassification.from_pretrained, - "backbones", - _HUGGINGFACE, - ) -else: - HUGGINGFACE_BACKBONES = FlashRegistry("backbones") - class TextClassifier(ClassificationTask): """The ``TextClassifier`` is a :class:`~flash.Task` for classifying text. For more details, see @@ -64,7 +55,7 @@ class TextClassifier(ClassificationTask): required_extras: str = "text" - backbones: FlashRegistry = FlashRegistry("backbones") + HUGGINGFACE_BACKBONES + backbones: FlashRegistry = TEXT_CLASSIFIER_BACKBONES def __init__( self, diff --git a/flash/text/question_answering/cli.py b/flash/text/question_answering/cli.py index 12932ae930..471cf13eca 100644 --- a/flash/text/question_answering/cli.py +++ b/flash/text/question_answering/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -23,15 +22,15 @@ def from_squad( backbone: str = "distilbert-base-uncased", batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> QuestionAnsweringData: - """Downloads and loads the XSum data set.""" - download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/") - download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/") + """Downloads and loads a tiny subset of the squad V2 data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/") + return QuestionAnsweringData.from_squad_v2( - train_file="./data/train-v2.0.json", - val_file="./data/dev-v2.0.json", + train_file="./data/squad_tiny/train.json", + val_file="./data/squad_tiny/val.json", backbone=backbone, batch_size=batch_size, num_workers=num_workers, diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 1948ca9b21..35f3af8df7 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -645,7 +645,7 @@ def from_squad_v2( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs: Any, ): """Creates a :class:`~flash.text.question_answering.data.QuestionAnsweringData` object from the given data @@ -713,7 +713,7 @@ def from_json( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Sampler] = None, field: Optional[str] = None, **preprocess_kwargs: Any, @@ -814,7 +814,7 @@ def from_csv( preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, ) -> "DataModule": diff --git a/flash/text/seq2seq/summarization/cli.py b/flash/text/seq2seq/summarization/cli.py index 666dd87f40..25003cb58b 100644 --- a/flash/text/seq2seq/summarization/cli.py +++ b/flash/text/seq2seq/summarization/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -23,7 +22,7 @@ def from_xsum( backbone: str = "sshleifer/distilbart-xsum-1-1", batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> SummarizationData: """Downloads and loads the XSum data set.""" diff --git a/flash/text/seq2seq/translation/cli.py b/flash/text/seq2seq/translation/cli.py index 1609cb4de0..66ec698791 100644 --- a/flash/text/seq2seq/translation/cli.py +++ b/flash/text/seq2seq/translation/cli.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -23,7 +22,7 @@ def from_wmt_en_ro( backbone: str = "Helsinki-NLP/opus-mt-en-ro", batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> TranslationData: """Downloads and loads the WMT EN RO data set.""" diff --git a/flash/video/classification/cli.py b/flash/video/classification/cli.py index 840386506b..3053d0c1ca 100644 --- a/flash/video/classification/cli.py +++ b/flash/video/classification/cli.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -26,7 +25,7 @@ def from_kinetics( clip_duration: int = 1, decode_audio: bool = False, batch_size: int = 4, - num_workers: Optional[int] = None, + num_workers: int = 0, **preprocess_kwargs, ) -> VideoClassificationData: """Downloads and loads the Kinetics data set.""" diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py deleted file mode 100644 index 15cc3b9fc7..0000000000 --- a/flash_examples/custom_task.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple - -import numpy as np -import torch -from pytorch_lightning import seed_everything -from torch import nn, Tensor - -import flash -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Preprocess -from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _SKLEARN_AVAILABLE - -if _SKLEARN_AVAILABLE: - from sklearn import datasets -else: - raise ModuleNotFoundError("Please pip install scikit-learn") - -seed_everything(42) - -ND = np.ndarray - - -class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.2, metrics=None): - # what kind of model do we want? - model = nn.Linear(num_inputs, 1) - - # what loss function do we want? - loss_fn = torch.nn.functional.mse_loss - - # what optimizer to do we want? - optimizer = torch.optim.Adam - - super().__init__( - model=model, - loss_fn=loss_fn, - optimizer=optimizer, - metrics=metrics, - learning_rate=learning_rate, - ) - - def training_step(self, batch: Any, batch_idx: int) -> Any: - return super().training_step( - (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), - batch_idx, - ) - - def validation_step(self, batch: Any, batch_idx: int) -> None: - return super().validation_step( - (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), - batch_idx, - ) - - def test_step(self, batch: Any, batch_idx: int) -> None: - return super().test_step( - (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), - batch_idx, - ) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - return super().predict_step( - batch[DefaultDataKeys.INPUT], - batch_idx, - dataloader_idx, - ) - - def forward(self, x): - # we don't actually need to override this method for this example - return self.model(x) - - -class NumpyDataSource(DataSource[Tuple[ND, ND]]): - def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: - if self.training: - dataset.num_inputs = data[0].shape[1] - return [{DefaultDataKeys.INPUT: x, DefaultDataKeys.TARGET: y} for x, y in zip(*data)] - - @staticmethod - def predict_load_data(data: ND) -> List[Dict[str, Any]]: - return [{DefaultDataKeys.INPUT: x} for x in data] - - -class NumpyPreprocess(Preprocess): - def __init__( - self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - ): - super().__init__( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_sources={DefaultDataSources.NUMPY: NumpyDataSource()}, - default_data_source=DefaultDataSources.NUMPY, - ) - - @staticmethod - def to_float(x: Tensor): - return x.float() - - @staticmethod - def format_targets(x: Tensor): - return x.unsqueeze(0) - - @property - def to_tensor(self) -> Dict[str, Callable]: - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys( - DefaultDataKeys.INPUT, - torch.from_numpy, - self.to_float, - ), - ApplyToKeys( - DefaultDataKeys.TARGET, - torch.as_tensor, - self.to_float, - self.format_targets, - ), - ), - } - - def default_transforms(self) -> Optional[Dict[str, Callable]]: - return self.to_tensor - - def get_state_dict(self) -> Dict[str, Any]: - return self.transforms - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): - return cls(*state_dict) - - -class NumpyDataModule(flash.DataModule): - - preprocess_cls = NumpyPreprocess - - -x, y = datasets.load_diabetes(return_X_y=True) -datamodule = NumpyDataModule.from_numpy(x, y) -model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs) - -trainer = flash.Trainer( - max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count() -) -trainer.fit(model, datamodule=datamodule) - -predict_data = np.array( - [ - [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], - [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], - [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], - [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], - ] -) - -predictions = model.predict(predict_data) -print(predictions) -# out: [tensor([188.9760]), tensor([196.1777]), tensor([161.3590]), tensor([130.7312]), tensor([149.0340])] diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py index 5a4de94fcf..72e81e2bde 100644 --- a/flash_examples/image_embedder.py +++ b/flash_examples/image_embedder.py @@ -11,15 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch +from torchvision.datasets import CIFAR10 + +import flash from flash.core.data.utils import download_data -from flash.image import ImageEmbedder +from flash.image import ImageClassificationData, ImageEmbedder -# 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") +# 1. Download the data and prepare the datamodule +datamodule = ImageClassificationData.from_datasets( + train_dataset=CIFAR10(".", download=True), + batch_size=16, +) # 2. Build the task -embedder = ImageEmbedder(backbone="resnet101") +embedder = ImageEmbedder( + backbone="resnet", + training_strategy="barlow_twins", + head="simclr_head", + pretraining_transform="barlow_twins_transform", + training_strategy_kwargs={"latent_embedding_dim": 128}, + pretraining_transform_kwargs={"size_crops": [196]}, +) + +# 3. Create the trainer and pre-train the encoder +# use accelerator='ddp' when using GPU(s), +# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp') +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.fit(embedder, datamodule=datamodule) + +# 4. Save the model! +trainer.save_checkpoint("image_embedder_model.pt") + +# 5. Download the downstream prediction dataset and generate embeddings +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") -# 3. Generate an embedding from an image path. -embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"]) +embeddings = embedder.predict( + [ + "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg", + "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg", + ] +) +# list of embeddings for images sent to the predict function print(embeddings) diff --git a/flash_examples/integrations/baal/image_classification_active_learning.py b/flash_examples/integrations/baal/image_classification_active_learning.py new file mode 100644 index 0000000000..dd9e82fc57 --- /dev/null +++ b/flash_examples/integrations/baal/image_classification_active_learning.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.classification import Probabilities +from flash.core.data.utils import download_data +from flash.image import ImageClassificationData, ImageClassifier +from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop + +# 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") + +# Implement the research use-case where we mask labels from labelled dataset. +datamodule = ActiveLearningDataModule( + ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2), + val_split=0.1, +) + +# 2. Build the task +head = torch.nn.Sequential( + torch.nn.Dropout(p=0.1), + torch.nn.Linear(512, datamodule.num_classes), +) +model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities()) + + +# 3.1 Create the trainer +trainer = flash.Trainer(max_epochs=3) + +# 3.2 Create the active learning loop and connect it to the trainer +active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) +active_learning_loop.connect(trainer.fit_loop) +trainer.fit_loop = active_learning_loop + +# 3.3 Finetune +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Predict what's on a few images! ants or bees? +predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("image_classification_model.pt") diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py new file mode 100644 index 0000000000..38bd6c2e7e --- /dev/null +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154 + +import warnings + +import kornia.augmentation as Ka +import kornia.geometry as Kg +import learn2learn as l2l +import torch +import torchvision +from torch import nn + +import flash +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys, kornia_collate +from flash.image import ImageClassificationData, ImageClassifier + +warnings.simplefilter("ignore") + +# download MiniImagenet +train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True) +val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True) +test_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="test", download=True) + +train_transform = { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + DefaultDataKeys.INPUT, + Kg.Resize((196, 196)), + # SPATIAL + Ka.RandomHorizontalFlip(p=0.25), + Ka.RandomRotation(degrees=90.0, p=0.25), + Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), + Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), + # PIXEL-LEVEL + Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness + Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation + Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast + Ka.ColorJitter(hue=1 / 30, p=0.25), # hue + Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), + Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), + ), + "collate": kornia_collate, + "per_batch_transform_on_device": ApplyToKeys( + DefaultDataKeys.INPUT, + Ka.RandomHorizontalFlip(p=0.25), + ), +} + +# construct datamodule +datamodule = ImageClassificationData.from_tensors( + train_data=train_dataset.x, + train_targets=torch.from_numpy(train_dataset.y.astype(int)), + val_data=val_dataset.x, + val_targets=torch.from_numpy(val_dataset.y.astype(int)), + test_data=test_dataset.x, + test_targets=torch.from_numpy(test_dataset.y.astype(int)), + num_workers=4, + train_transform=train_transform, +) + +model = ImageClassifier( + backbone="resnet18", + training_strategy="prototypicalnetworks", + training_strategy_kwargs={ + "epoch_length": 10 * 16, + "meta_batch_size": 4, + "num_tasks": 200, + "test_num_tasks": 2000, + "ways": datamodule.num_classes, + "shots": 1, + "test_ways": 5, + "test_shots": 1, + "test_queries": 15, + }, + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, +) + +trainer = flash.Trainer( + max_epochs=200, + gpus=2, + acceletator="ddp_shared", + precision=16, +) +trainer.finetune(model, datamodule=datamodule, strategy="no_freeze") diff --git a/flash_examples/question_answering.py b/flash_examples/question_answering.py index 78e119d108..8620e5aed1 100644 --- a/flash_examples/question_answering.py +++ b/flash_examples/question_answering.py @@ -16,12 +16,11 @@ from flash.text import QuestionAnsweringData, QuestionAnsweringTask # 1. Create the DataModule -download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/") -download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/") +download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/") datamodule = QuestionAnsweringData.from_squad_v2( - train_file="./data/train-v2.0.json", - val_file="./data/dev-v2.0.json", + train_file="./data/squad_tiny/train.json", + val_file="./data/squad_tiny/val.json", ) # 2. Build the task diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 9b0ff39748..7704ebb7a2 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -5,3 +5,6 @@ vissl>=0.1.5 icevision>=0.8 icedata effdet +albumentations +learn2learn +baal diff --git a/tests/core/integrations/vissl/test_strategies.py b/tests/core/integrations/vissl/test_strategies.py new file mode 100644 index 0000000000..443fe6b7e1 --- /dev/null +++ b/tests/core/integrations/vissl/test_strategies.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE +from flash.image.embedding.heads.vissl_heads import SimCLRHead +from flash.image.embedding.vissl.hooks import TrainingSetupHook + +if _VISSL_AVAILABLE: + from vissl.hooks.dino_hooks import DINOHook + from vissl.hooks.moco_hooks import MoCoHook + from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook + from vissl.losses.barlow_twins_loss import BarlowTwinsLoss + from vissl.losses.dino_loss import DINOLoss + from vissl.losses.moco_loss import MoCoLoss + from vissl.losses.swav_loss import SwAVLoss + from vissl.models.heads.swav_prototypes_head import SwAVPrototypesHead + + from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES +else: + DINOHook = object + MoCoHook = object + NormalizePrototypesHook = object + SwAVUpdateQueueScoresHook = object + + BarlowTwinsLoss = object + DINOLoss = object + MoCoLoss = object + SwAVLoss = object + + SwAVPrototypesHead = object + + IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies") + + +@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +@pytest.mark.parametrize( + "training_strategy, head_name, loss_fn_class, head_class, hooks_list", + [ + ("barlow_twins", "barlow_twins_head", BarlowTwinsLoss, SimCLRHead, [TrainingSetupHook]), + ("moco", "moco_head", MoCoLoss, SimCLRHead, [MoCoHook, TrainingSetupHook]), + ( + "swav", + "swav_head", + SwAVLoss, + SwAVPrototypesHead, + [SwAVUpdateQueueScoresHook, NormalizePrototypesHook, TrainingSetupHook], + ), + ("dino", "dino_head", DINOLoss, SwAVPrototypesHead, [DINOHook, TrainingSetupHook]), + ], +) +def test_vissl_strategies(tmpdir, training_strategy, head_name, loss_fn_class, head_class, hooks_list): + ret_loss_fn, ret_head, ret_hooks = IMAGE_EMBEDDER_STRATEGIES.get(training_strategy)(head=head_name) + + assert isinstance(ret_loss_fn, loss_fn_class) + assert isinstance(ret_head, head_class) + for hook in hooks_list: + hook_present = 0 + for ret_hook in ret_hooks: + if isinstance(ret_hook, hook): + hook_present = 1 + + assert hook_present == 1 diff --git a/tests/core/integrations/vissl/test_transforms.py b/tests/core/integrations/vissl/test_transforms.py index d40913f58f..fa379acda3 100644 --- a/tests/core/integrations/vissl/test_transforms.py +++ b/tests/core/integrations/vissl/test_transforms.py @@ -14,53 +14,28 @@ import pytest from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import DefaultPreprocess -from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE -from flash.image import ImageClassificationData - -if _TORCHVISION_AVAILABLE: - from torchvision.datasets import FakeData - -if _VISSL_AVAILABLE: - from classy_vision.dataset.transforms import TRANSFORM_REGISTRY - - from flash.core.integrations.vissl.transforms import vissl_collate_fn +from tests.image.embedding.utils import ssl_datamodule @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") def test_multicrop_input_transform(): batch_size = 8 - total_crops = 6 + total_num_crops = 6 num_crops = [2, 4] size_crops = [160, 96] crop_scales = [[0.4, 1], [0.05, 0.4]] - multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( - total_crops, num_crops, size_crops, crop_scales - ) - - to_tensor_transform = ApplyToKeys( - DefaultDataKeys.INPUT, - multi_crop_transform, - ) - preprocess = DefaultPreprocess( - train_transform={ - "to_tensor_transform": to_tensor_transform, - "collate": vissl_collate_fn, - } - ) - - datamodule = ImageClassificationData.from_datasets( - train_dataset=FakeData(), - preprocess=preprocess, + train_dataloader = ssl_datamodule( batch_size=batch_size, - ) - - train_dataloader = datamodule._train_dataloader() + total_num_crops=total_num_crops, + num_crops=num_crops, + size_crops=size_crops, + crop_scales=crop_scales, + )._train_dataloader() batch = next(iter(train_dataloader)) - assert len(batch[DefaultDataKeys.INPUT]) == total_crops + assert len(batch[DefaultDataKeys.INPUT]) == total_num_crops assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0]) assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1]) assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size] diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 156669a657..7c77f69075 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform - import torch from flash import DataModule @@ -54,7 +52,4 @@ def test_dataloaders(): def test_cpu_count_none(): train_ds = DummyDataset() dm = DataModule(train_ds, num_workers=None) - if platform.system() == "Darwin" or platform.system() == "Windows": - assert dm.num_workers == 0 - else: - assert dm.num_workers > 0 + assert dm.num_workers == 0 diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 3d3b53b111..148a8c06a3 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -248,14 +248,15 @@ def test_task_datapipeline_save(tmpdir): @pytest.mark.parametrize( ["cls", "filename"], [ - pytest.param( - ImageClassifier, - "image_classification_model.pt", - marks=pytest.mark.skipif( - not _IMAGE_TESTING, - reason="image packages aren't installed", - ), - ), + # needs to be updated. + # pytest.param( + # ImageClassifier, + # "image_classification_model.pt", + # marks=pytest.mark.skipif( + # not _IMAGE_TESTING, + # reason="image packages aren't installed", + # ), + # ), pytest.param( TabularClassifier, "tabular_classification_model.pt", diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py index 5fe061c678..4923099df4 100644 --- a/tests/examples/test_integrations.py +++ b/tests/examples/test_integrations.py @@ -17,7 +17,7 @@ import pytest -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _BAAL_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from tests.examples.utils import run_test root = Path(__file__).parent.parent.parent @@ -34,6 +34,11 @@ not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" ), ), + pytest.param( + "baal", + "image_classification_active_learning.py", + marks=pytest.mark.skipif(not (_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="baal library isn't installed"), + ), ], ) def test_integrations(tmpdir, folder, file): diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index adbb1025a5..1060e43eb2 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -18,7 +18,7 @@ import pytest import flash -from flash.core.utilities.imports import _SKLEARN_AVAILABLE +from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE, _SKLEARN_AVAILABLE from tests.examples.utils import run_test from tests.helpers.utils import ( _AUDIO_TESTING, @@ -35,9 +35,6 @@ @pytest.mark.parametrize( "file", [ - pytest.param( - "custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") - ), pytest.param( "audio_classification.py", marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), @@ -54,6 +51,12 @@ "image_classification_multi_label.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), + pytest.param( + "image_classification_meta_learning.py.py", + marks=pytest.mark.skipif( + not (_IMAGE_TESTING and _LEARN2LEARN_AVAILABLE), reason="image/learn2learn libraries aren't installed" + ), + ), # pytest.param("finetuning", "object_detection.py"), # TODO: takes too long. pytest.param( "question_answering.py", diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py new file mode 100644 index 0000000000..aae92a847c --- /dev/null +++ b/tests/image/classification/test_active_learning.py @@ -0,0 +1,85 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import numpy as np +import pytest +from pytorch_lightning import seed_everything +from torch import nn + +import flash +from flash.core.classification import Probabilities +from flash.core.utilities.imports import _BAAL_AVAILABLE +from flash.image import ImageClassificationData, ImageClassifier +from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop +from tests.helpers.utils import _IMAGE_TESTING +from tests.image.classification.test_data import _rand_image + +# ======== Mock functions ======== + + +@pytest.mark.skipif(not (_IMAGE_TESTING and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed.") +def test_active_learning_training(tmpdir): + seed_everything(42) + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + pa_1 = train_dir / "a" / "1.png" + pa_2 = train_dir / "a" / "2.png" + pb_1 = train_dir / "b" / "1.png" + pb_2 = train_dir / "b" / "2.png" + image_size = (96, 96) + _rand_image(image_size).save(pa_1) + _rand_image(image_size).save(pa_2) + + (train_dir / "b").mkdir() + _rand_image(image_size).save(pb_1) + _rand_image(image_size).save(pb_2) + + n = 5 + dm = ImageClassificationData.from_files( + train_files=[str(pa_1)] * n + [str(pa_2)] * n + [str(pb_1)] * n + [str(pb_2)] * n, + train_targets=[0] * n + [1] * n + [2] * n + [3] * n, + test_files=[str(pa_1)] * n, + test_targets=[0] * n, + batch_size=2, + num_workers=0, + image_size=image_size, + ) + + active_learning_dm = ActiveLearningDataModule( + dm, + val_split=0.5, + ) + + head = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(512, active_learning_dm.num_classes), + ) + model = ImageClassifier( + backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities() + ) + trainer = flash.Trainer(max_epochs=3) + + active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) + active_learning_loop.connect(trainer.fit_loop) + trainer.fit_loop = active_learning_loop + + trainer.finetune(model, datamodule=active_learning_dm, strategy="freeze") + assert len(active_learning_dm._dataset) == 15 + assert active_learning_loop.progress.total.completed == 3 + labelled = active_learning_loop.state_dict()["state_dict"]["datamodule_state_dict"]["labelled"] + assert isinstance(labelled, np.ndarray) + assert len(active_learning_dm.val_dataloader()) == 4 diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 99bf240646..c7773b7377 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -21,8 +21,9 @@ import torch.nn as nn from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.transforms import ApplyToKeys +from flash.core.data.transforms import ApplyToKeys, merge_transforms from flash.core.utilities.imports import ( + _ALBUMENTATIONS_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, @@ -30,6 +31,7 @@ _TORCHVISION_AVAILABLE, ) from flash.image import ImageClassificationData +from flash.image.classification.transforms import AlbumentationsAdapter, default_transforms from tests.helpers.utils import _IMAGE_TESTING if _TORCHVISION_AVAILABLE: @@ -42,6 +44,9 @@ if _FIFTYONE_AVAILABLE: import fiftyone as fo +if _ALBUMENTATIONS_AVAILABLE: + import albumentations + def _dummy_image_loader(_): return torch.rand(3, 196, 196) @@ -569,3 +574,40 @@ def test_from_bad_csv_no_image(bad_csv_no_image): num_workers=0, ) _ = next(iter(img_data.train_dataloader())) + + +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _ALBUMENTATIONS_AVAILABLE, reason="albumentations isn't installed.") +def test_albumentations_mixup(single_target_csv): + def mixup(batch, alpha=1.0): + images = batch["input"] + targets = batch["target"].float().unsqueeze(1) + + lam = np.random.beta(alpha, alpha) + perm = torch.randperm(images.size(0)) + + batch["input"] = images * lam + images[perm] * (1 - lam) + batch["target"] = targets * lam + targets[perm] * (1 - lam) + for e in batch["metadata"]: + e.update({"lam": lam}) + return batch + + train_transform = { + # applied only on images as ApplyToKeys is used with `input` + "post_tensor_transform": ApplyToKeys("input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))), + "per_batch_transform": mixup, + } + # merge the default transform for this task with new one. + train_transform = merge_transforms(default_transforms((256, 256)), train_transform) + + img_data = ImageClassificationData.from_csv( + "image", + "target", + train_file=single_target_csv, + batch_size=2, + num_workers=0, + train_transform=train_transform, + ) + + batch = next(iter(img_data.train_dataloader())) + assert "lam" in batch["metadata"][0] diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py new file mode 100644 index 0000000000..746880b4be --- /dev/null +++ b/tests/image/classification/test_training_strategies.py @@ -0,0 +1,115 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import pytest +import torch +from torch.utils.data import DataLoader + +from flash import Trainer +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE +from flash.image import ImageClassificationData, ImageClassifier +from flash.image.classification.adapters import TRAINING_STRATEGIES +from tests.helpers.utils import _IMAGE_TESTING +from tests.image.classification.test_data import _rand_image + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + def __getitem__(self, index): + return { + DefaultDataKeys.INPUT: torch.rand(3, 96, 96), + DefaultDataKeys.TARGET: torch.randint(10, size=(1,)).item(), + } + + def __len__(self) -> int: + return 2 + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_default_strategies(tmpdir): + num_classes = 10 + ds = DummyDataset() + model = ImageClassifier(num_classes, backbone="resnet50") + + trainer = Trainer(fast_dev_run=2) + trainer.fit(model, train_dataloader=DataLoader(ds)) + + +@pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") +def test_learn2learn_training_strategies_registry(): + assert TRAINING_STRATEGIES.available_keys() == ["anil", "default", "maml", "metaoptnet", "prototypicalnetworks"] + + +def _test_learn2learning_training_strategies(gpus, accelerator, training_strategy, tmpdir): + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + pa_1 = train_dir / "a" / "1.png" + pa_2 = train_dir / "a" / "2.png" + pb_1 = train_dir / "b" / "1.png" + pb_2 = train_dir / "b" / "2.png" + image_size = (96, 96) + _rand_image(image_size).save(pa_1) + _rand_image(image_size).save(pa_2) + + (train_dir / "b").mkdir() + _rand_image(image_size).save(pb_1) + _rand_image(image_size).save(pb_2) + + n = 5 + + dm = ImageClassificationData.from_files( + train_files=[str(pa_1)] * n + [str(pa_2)] * n + [str(pb_1)] * n + [str(pb_2)] * n, + train_targets=[0] * n + [1] * n + [2] * n + [3] * n, + batch_size=1, + num_workers=0, + image_size=image_size, + ) + + model = ImageClassifier( + backbone="resnet18", + training_strategy=training_strategy, + training_strategy_kwargs={"ways": dm.num_classes, "shots": 4, "meta_batch_size": 4}, + ) + + trainer = Trainer(fast_dev_run=2, gpus=gpus, accelerator=accelerator) + trainer.fit(model, datamodule=dm) + + +# 'metaoptnet' is not yet supported as it requires qpth as a dependency. +@pytest.mark.parametrize("training_strategy", ["anil", "maml", "prototypicalnetworks"]) +@pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") +def test_learn2learn_training_strategies(training_strategy, tmpdir): + _test_learn2learning_training_strategies(0, None, training_strategy, tmpdir) + + +@pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") +def test_wrongly_specified_training_strategies(): + with pytest.raises(KeyError, match="something is not in FlashRegistry"): + ImageClassifier( + backbone="resnet18", + training_strategy="something", + training_strategy_kwargs={"ways": 2, "shots": 4, "meta_batch_size": 10}, + ) + + +@pytest.mark.skipif(not os.getenv("FLASH_RUNNING_SPECIAL_TESTS", "0") == "1", reason="Should run with special test") +@pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") +def test_learn2learn_training_strategies_ddp(tmpdir): + _test_learn2learning_training_strategies(2, "ddp", "prototypicalnetworks", tmpdir) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 7782cb4409..3893cdc242 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -93,8 +93,8 @@ def test_init(): def test_training(tmpdir, head): model = ObjectDetector(num_classes=2, head=head, pretrained=False) ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) - dl = model.process_train_dataset(ds, 2, 0, False, None) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + dl = model.process_train_dataset(ds, trainer, 2, 0, False, None) trainer.fit(model, dl) diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index e823212ef7..f17287dafd 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -11,36 +11,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import re import pytest import torch -from flash.core.utilities.imports import _IMAGE_AVAILABLE -from flash.image import ImageEmbedder -from tests.helpers.utils import _IMAGE_TESTING +import flash +from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE +from flash.image import ImageClassificationData, ImageEmbedder +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import FakeData +else: + FakeData = object -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) -def test_jit(tmpdir, jitter, args): - path = os.path.join(tmpdir, "test.pt") - - model = ImageEmbedder(embedding_dim=128) - model.eval() - - model = jitter(model, *args) - - torch.jit.save(model, path) - model = torch.jit.load(path) - - out = model(torch.rand(1, 3, 32, 32)) - assert isinstance(out, torch.Tensor) - assert out.shape == torch.Size([1, 128]) +# TODO: Figure out why VISSL can't be jitted +# @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +# @pytest.mark.parametrize("jitter, args", [(torch.jit.trace, (torch.rand(1, 3, 64, 64),))]) +# def test_jit(tmpdir, jitter, args): +# path = os.path.join(tmpdir, "test.pt") +# +# model = ImageEmbedder(training_strategy="barlow_twins") +# model.eval() +# +# model = jitter(model, *args) +# +# torch.jit.save(model, path) +# model = torch.jit.load(path) +# +# out = model(torch.rand(1, 3, 64, 64)) +# assert isinstance(out, torch.Tensor) +# assert out.shape == torch.Size([1, 2048]) @pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.") def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): ImageEmbedder.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +@pytest.mark.parametrize("backbone, training_strategy", [("resnet", "barlow_twins")]) +def test_vissl_training(tmpdir, backbone, training_strategy): + datamodule = ImageClassificationData.from_datasets( + train_dataset=FakeData(), + batch_size=4, + ) + + embedder = ImageEmbedder( + backbone=backbone, + training_strategy=training_strategy, + head="simclr_head", + pretraining_transform="barlow_twins_transform", + training_strategy_kwargs={"latent_embedding_dim": 128}, + pretraining_transform_kwargs={ + "total_num_crops": 2, + "num_crops": [2], + "size_crops": [96], + "crop_scales": [[0.4, 1]], + }, + ) + + trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count()) + trainer.fit(embedder, datamodule=datamodule) diff --git a/tests/image/embedding/utils.py b/tests/image/embedding/utils.py new file mode 100644 index 0000000000..9f04e4a4c7 --- /dev/null +++ b/tests/image/embedding/utils.py @@ -0,0 +1,44 @@ +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import DefaultPreprocess +from flash.core.data.transforms import ApplyToKeys +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE +from flash.image import ImageClassificationData +from flash.image.embedding.vissl.transforms import multicrop_collate_fn + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import FakeData + +if _VISSL_AVAILABLE: + from classy_vision.dataset.transforms import TRANSFORM_REGISTRY + + +def ssl_datamodule( + batch_size=2, + total_num_crops=4, + num_crops=[2, 2], + size_crops=[160, 96], + crop_scales=[[0.4, 1], [0.05, 0.4]], + collate_fn=multicrop_collate_fn, +): + multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"]( + total_num_crops, num_crops, size_crops, crop_scales + ) + + to_tensor_transform = ApplyToKeys( + DefaultDataKeys.INPUT, + multi_crop_transform, + ) + preprocess = DefaultPreprocess( + train_transform={ + "to_tensor_transform": to_tensor_transform, + "collate": collate_fn, + } + ) + + datamodule = ImageClassificationData.from_datasets( + train_dataset=FakeData(), + preprocess=preprocess, + batch_size=batch_size, + ) + + return datamodule diff --git a/tests/special_tests.sh b/tests/special_tests.sh new file mode 100644 index 0000000000..99cac8929a --- /dev/null +++ b/tests/special_tests.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +# this environment variable allows special tests to run +export FLASH_RUNNING_SPECIAL_TESTS=1 +# python arguments +defaults='-m coverage run --source flash --append -m pytest --durations=0 --capture=no --disable-warnings' + +# find tests marked as `@RunIf(special=True)` +grep_output=$(grep --recursive --line-number --word-regexp 'tests' --regexp 'os.getenv("FLASH_RUNNING_SPECIAL_TESTS",') +# file paths +files=$(echo "$grep_output" | cut -f1 -d:) +files_arr=($files) +echo $files + +# line numbers +linenos=$(echo "$grep_output" | cut -f2 -d:) +linenos_arr=($linenos) + +# tests to skip - space separated +blocklist='test_pytorch_profiler_nested_emit_nvtx' +report='' + +for i in "${!files_arr[@]}"; do + file=${files_arr[$i]} + lineno=${linenos_arr[$i]} + + # get code from `@RunIf(special=True)` line to EOF + test_code=$(tail -n +"$lineno" "$file") + + # read line by line + while read -r line; do + # if it's a test + if [[ $line == def\ test_* ]]; then + # get the name + test_name=$(echo $line | cut -c 5- | cut -f1 -d\() + + # check blocklist + if echo $blocklist | grep --word-regexp "$test_name" > /dev/null; then + report+="Skipped\t$file:$lineno::$test_name\n" + break + fi + + # SPECIAL_PATTERN allows filtering the tests to run when debugging. + # use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those + # test with `foo_bar` in their name + if [[ $line != *$SPECIAL_PATTERN* ]]; then + report+="Skipped\t$file:$lineno::$test_name\n" + break + fi + + # run the test + report+="Ran\t$file:$lineno::$test_name\n" + python ${defaults} "${file}::${test_name}" + break + fi + done < <(echo "$test_code") +done + +# echo test report +printf '=%.s' {1..80} +printf "\n$report" +printf '=%.s' {1..80} +printf '\n' diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 4c42909b35..238f419522 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -14,12 +14,14 @@ import os from pathlib import Path +import pandas as pd import pytest from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassificationData from flash.text.classification.data import ( TextCSVDataSource, + TextDataFrameDataSource, TextDataSource, TextFileDataSource, TextJSONDataSource, @@ -51,6 +53,11 @@ """ +TEST_DATA_FRAME_DATA = pd.DataFrame( + {"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"], "lab": [0, 1, 0]}, +) + + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" path.write_text(TEST_CSV_DATA) @@ -123,6 +130,17 @@ def test_from_json_with_field(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_data_frame(): + dm = TextClassificationData.from_data_frame( + "sentence", "lab", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1 + ) + batch = next(iter(dm.train_dataloader())) + assert batch["labels"].item() in [0, 1] + assert "input_ids" in batch + + @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") def test_text_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[text]"): @@ -138,6 +156,7 @@ def test_text_module_not_found_error(): (TextFileDataSource, {"filetype": "csv"}), (TextCSVDataSource, {}), (TextJSONDataSource, {}), + (TextDataFrameDataSource, {}), (TextSentencesDataSource, {}), ], )