Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Enable tests for examples (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta authored Feb 17, 2021
1 parent 24c5b66 commit a6edeab
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 23 deletions.
2 changes: 1 addition & 1 deletion flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/finetuning/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flash.text import SummarizationData, SummarizationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Load the data
datamodule = SummarizationData.from_files(
Expand Down
10 changes: 5 additions & 5 deletions flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data import download_data
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data import download_data
from flash.tabular import TabularClassifier, TabularData

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

# 2. Load the data
datamodule = TabularData.from_csv(
Expand Down
8 changes: 4 additions & 4 deletions flash_examples/finetuning/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import flash
from flash.core.data import download_data
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -13,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the data
datamodule = TextClassificationData.from_files(
Expand All @@ -35,7 +35,7 @@
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy='freeze')
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Test model
trainer.test()
Expand Down
6 changes: 3 additions & 3 deletions flash_examples/finetuning/translation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import flash
from flash import download_data
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -13,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# 2. Load the data
datamodule = TranslationData.from_files(
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/predict/classify_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/predict/classify_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flash.tabular import TabularClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/predict/classify_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
Expand Down
4 changes: 2 additions & 2 deletions flash_examples/predict/image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from flash.vision import ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Create an ImageEmbedder with swav trained on imagenet.
# Check out SWAV: https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)

# 3. Generate an embedding from an image path.
embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"])

# 4. Print embeddings shape
print(embeddings.shape)
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/predict/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.text import SummarizationData, SummarizationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Load the model from a checkpoint
model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/predict/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.text import TranslationData, TranslationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

# 2. Load the model from a checkpoint
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
Expand Down
11 changes: 9 additions & 2 deletions tests/examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,20 @@ def run_test(filepath):
"step,file",
[
("finetuning", "image_classification.py"),
# ("finetuning", "object_detection.py"), # TODO: takes too long.
# ("finetuning", "summarization.py"), # TODO: takes too long.
("finetuning", "tabular_classification.py"),
("finetuning", "text_classification.py"),
# ("finetuning", "translation.py"), # TODO: takes too long.
("predict", "classify_image.py"),
("predict", "classify_tabular.py"),
# "classify_text.py" TODO: takes too long
("predict", "classify_text.py"),
("predict", "image_embedder.py"),
("predict", "summarize.py"),
# ("predict", "translate.py"), # TODO: takes too long
]
)
def test_finetune_example(tmpdir, step, file):
def test_example(tmpdir, step, file):
run_test(str(root / "flash_examples" / step / file))


Expand Down

0 comments on commit a6edeab

Please sign in to comment.