diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index c05969cc5d..f4f2b596e7 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -17,7 +17,7 @@ from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the data datamodule = ImageClassificationData.from_folders( diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py index 9f7b3e9fd7..e8ac6d8fcf 100644 --- a/flash_examples/finetuning/summarization.py +++ b/flash_examples/finetuning/summarization.py @@ -16,7 +16,7 @@ from flash.text import SummarizationData, SummarizationTask # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the data datamodule = SummarizationData.from_files( diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index d4b7419abc..e9769296d3 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -1,7 +1,3 @@ -from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall - -import flash -from flash.core.data import download_data # Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall + +import flash +from flash.core.data import download_data from flash.tabular import TabularClassifier, TabularData # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the data datamodule = TabularData.from_csv( diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 39fde33e45..4b5155b62d 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -1,5 +1,3 @@ -import flash -from flash.core.data import download_data # Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import flash +from flash.core.data import download_data from flash.text import TextClassificationData, TextClassifier # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/") # 2. Load the data datamodule = TextClassificationData.from_files( @@ -35,7 +35,7 @@ trainer = flash.Trainer(max_epochs=1) # 5. Fine-tune the model -trainer.finetune(model, datamodule=datamodule, strategy='freeze') +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 6. Test model trainer.test() diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index 6035cb6aaa..d7a4c043eb 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -1,5 +1,3 @@ -import flash -from flash import download_data # Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import flash +from flash import download_data from flash.text import TranslationData, TranslationTask # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") # 2. Load the data datamodule = TranslationData.from_files( diff --git a/flash_examples/predict/classify_image.py b/flash_examples/predict/classify_image.py index 83040a3b98..82b21b588b 100644 --- a/flash_examples/predict/classify_image.py +++ b/flash_examples/predict/classify_image.py @@ -16,7 +16,7 @@ from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") diff --git a/flash_examples/predict/classify_tabular.py b/flash_examples/predict/classify_tabular.py index 76777e67ab..cb2772361f 100644 --- a/flash_examples/predict/classify_tabular.py +++ b/flash_examples/predict/classify_tabular.py @@ -15,7 +15,7 @@ from flash.tabular import TabularClassifier # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the model from a checkpoint model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt") diff --git a/flash_examples/predict/classify_text.py b/flash_examples/predict/classify_text.py index 9790f53941..9b4a74d30a 100644 --- a/flash_examples/predict/classify_text.py +++ b/flash_examples/predict/classify_text.py @@ -17,7 +17,7 @@ from flash.text import TextClassificationData, TextClassifier # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/") # 2. Load the model from a checkpoint model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py index 6f1b0026b9..68cf7b4e78 100644 --- a/flash_examples/predict/image_embedder.py +++ b/flash_examples/predict/image_embedder.py @@ -17,14 +17,14 @@ from flash.vision import ImageEmbedder # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Create an ImageEmbedder with swav trained on imagenet. # Check out SWAV: https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128) # 3. Generate an embedding from an image path. -embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg') +embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"]) # 4. Print embeddings shape print(embeddings.shape) diff --git a/flash_examples/predict/summarize.py b/flash_examples/predict/summarize.py index 05d3a8984e..172a7e67da 100644 --- a/flash_examples/predict/summarize.py +++ b/flash_examples/predict/summarize.py @@ -17,7 +17,7 @@ from flash.text import SummarizationData, SummarizationTask # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the model from a checkpoint model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") diff --git a/flash_examples/predict/translate.py b/flash_examples/predict/translate.py index 819ca2c4bf..a956a4af5a 100644 --- a/flash_examples/predict/translate.py +++ b/flash_examples/predict/translate.py @@ -17,7 +17,7 @@ from flash.text import TranslationData, TranslationTask # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") # 2. Load the model from a checkpoint model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 79477611f1..68ff6d27b6 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -53,13 +53,20 @@ def run_test(filepath): "step,file", [ ("finetuning", "image_classification.py"), + # ("finetuning", "object_detection.py"), # TODO: takes too long. + # ("finetuning", "summarization.py"), # TODO: takes too long. ("finetuning", "tabular_classification.py"), + ("finetuning", "text_classification.py"), + # ("finetuning", "translation.py"), # TODO: takes too long. ("predict", "classify_image.py"), ("predict", "classify_tabular.py"), - # "classify_text.py" TODO: takes too long + ("predict", "classify_text.py"), + ("predict", "image_embedder.py"), + ("predict", "summarize.py"), + # ("predict", "translate.py"), # TODO: takes too long ] ) -def test_finetune_example(tmpdir, step, file): +def test_example(tmpdir, step, file): run_test(str(root / "flash_examples" / step / file))