From 4a242e6e7d25e0ab9099531c1854a6a641bd3e4a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sun, 7 Nov 2021 13:42:19 +0000 Subject: [PATCH] Bump pretrained weights to 0.6.0 (#940) --- docs/source/general/predictions.rst | 6 +++--- docs/source/quickstart.rst | 2 +- .../integrations/fiftyone/image_classification.py | 2 +- .../image_classification_fiftyone_datasets.py | 2 +- .../serve/image_classification/inference_server.py | 2 +- .../semantic_segmentation/inference_server.py | 2 +- .../serve/speech_recognition/inference_server.py | 2 +- .../serve/summarization/inference_server.py | 2 +- .../tabular_classification/inference_server.py | 2 +- .../serve/text_classification/inference_server.py | 2 +- .../serve/translation/inference_server.py | 2 +- flash_notebooks/tabular_classification.ipynb | 2 +- tests/core/test_model.py | 14 +++++++------- 13 files changed, 21 insertions(+), 21 deletions(-) diff --git a/docs/source/general/predictions.rst b/docs/source/general/predictions.rst index 3181d9766e..da415045bd 100644 --- a/docs/source/general/predictions.rst +++ b/docs/source/general/predictions.rst @@ -24,7 +24,7 @@ You can pass in a sample of data (image file path, a string of text, etc) to the # 2. Load the model from a checkpoint model = ImageClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt" ) # 3. Predict whether the image contains an ant or a bee @@ -46,7 +46,7 @@ Predict on a csv file # 2. Load the model from a checkpoint model = TabularClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt" ) # 3. Generate predictions from a csv file! Who would survive? @@ -74,7 +74,7 @@ reference below). # 2. Load the model from a checkpoint model = ImageClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt" ) # 3. Attach the Output diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index c82eaa0faf..227c6dcaca 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -91,7 +91,7 @@ Here's an example of inference: from flash.text import TextClassifier # 1. Init the finetuned task from URL - model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt") + model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/text_classification_model.pt") # 2. Perform inference from list of sequences predictions = model.predict( diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py index f7fc0db39a..0cce2c09f3 100644 --- a/flash_examples/integrations/fiftyone/image_classification.py +++ b/flash_examples/integrations/fiftyone/image_classification.py @@ -54,7 +54,7 @@ # 4 Predict from checkpoint model = ImageClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt" ) model.output = FiftyOneLabels(return_filepath=True) # output FiftyOne format predictions = trainer.predict(model, datamodule=datamodule) diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py index 96ea5ffc51..83077c6d6e 100644 --- a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py +++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py @@ -67,7 +67,7 @@ # 5 Predict from checkpoint on data with ground truth model = ImageClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt" ) model.output = FiftyOneLabels(return_filepath=False) # output FiftyOne format datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset) diff --git a/flash_examples/serve/image_classification/inference_server.py b/flash_examples/serve/image_classification/inference_server.py index a20e147c97..00a519a63b 100644 --- a/flash_examples/serve/image_classification/inference_server.py +++ b/flash_examples/serve/image_classification/inference_server.py @@ -14,6 +14,6 @@ from flash.image import ImageClassifier model = ImageClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt" ) model.serve() diff --git a/flash_examples/serve/semantic_segmentation/inference_server.py b/flash_examples/serve/semantic_segmentation/inference_server.py index ea106da239..ca42c43d68 100644 --- a/flash_examples/serve/semantic_segmentation/inference_server.py +++ b/flash_examples/serve/semantic_segmentation/inference_server.py @@ -15,7 +15,7 @@ from flash.image.segmentation.output import SegmentationLabels model = SemanticSegmentation.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/semantic_segmentation_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/semantic_segmentation_model.pt" ) model.output = SegmentationLabels(visualize=False) model.serve() diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/flash_examples/serve/speech_recognition/inference_server.py index 34e21ca319..9bc9a0ed8f 100644 --- a/flash_examples/serve/speech_recognition/inference_server.py +++ b/flash_examples/serve/speech_recognition/inference_server.py @@ -14,6 +14,6 @@ from flash.audio import SpeechRecognition model = SpeechRecognition.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/speech_recognition_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/speech_recognition_model.pt" ) model.serve() diff --git a/flash_examples/serve/summarization/inference_server.py b/flash_examples/serve/summarization/inference_server.py index 8dea17bd40..caa35bc7c4 100644 --- a/flash_examples/serve/summarization/inference_server.py +++ b/flash_examples/serve/summarization/inference_server.py @@ -14,6 +14,6 @@ from flash.text import SummarizationTask model = SummarizationTask.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/summarization_model_xsum.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/summarization_model_xsum.pt" ) model.serve() diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py index c9365d773b..e92543f087 100644 --- a/flash_examples/serve/tabular_classification/inference_server.py +++ b/flash_examples/serve/tabular_classification/inference_server.py @@ -15,7 +15,7 @@ from flash.tabular import TabularClassifier model = TabularClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt" + "https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt" ) model.output = Labels(["Did not survive", "Survived"]) model.serve() diff --git a/flash_examples/serve/text_classification/inference_server.py b/flash_examples/serve/text_classification/inference_server.py index ad8ff098f0..a05462caa9 100644 --- a/flash_examples/serve/text_classification/inference_server.py +++ b/flash_examples/serve/text_classification/inference_server.py @@ -13,5 +13,5 @@ # limitations under the License. from flash.text import TextClassifier -model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt") +model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/text_classification_model.pt") model.serve() diff --git a/flash_examples/serve/translation/inference_server.py b/flash_examples/serve/translation/inference_server.py index 0c9ed2f894..406eb6883a 100644 --- a/flash_examples/serve/translation/inference_server.py +++ b/flash_examples/serve/translation/inference_server.py @@ -13,5 +13,5 @@ # limitations under the License. from flash.text import TranslationTask -model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/translation_model_en_ro.pt") +model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/translation_model_en_ro.pt") model.serve() diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index ac8ec85242..43d6ca454f 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -222,7 +222,7 @@ "outputs": [], "source": [ "model = TabularClassifier.load_from_checkpoint(\n", - " \"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt\")" + " \"https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt\")" ] }, { diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6a885fc3b6..9a465c3e10 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -250,7 +250,7 @@ def test_task_datapipeline_save(tmpdir): [ pytest.param( ImageClassifier, - "0.5.2/image_classification_model.pt", + "0.6.0/image_classification_model.pt", marks=pytest.mark.skipif( not _IMAGE_TESTING, reason="image packages aren't installed", @@ -258,7 +258,7 @@ def test_task_datapipeline_save(tmpdir): ), pytest.param( SemanticSegmentation, - "0.5.2/semantic_segmentation_model.pt", + "0.6.0/semantic_segmentation_model.pt", marks=pytest.mark.skipif( not _IMAGE_TESTING, reason="image packages aren't installed", @@ -266,7 +266,7 @@ def test_task_datapipeline_save(tmpdir): ), pytest.param( SpeechRecognition, - "0.5.2/speech_recognition_model.pt", + "0.6.0/speech_recognition_model.pt", marks=pytest.mark.skipif( not _AUDIO_TESTING, reason="audio packages aren't installed", @@ -274,7 +274,7 @@ def test_task_datapipeline_save(tmpdir): ), pytest.param( TabularClassifier, - "0.5.2/tabular_classification_model.pt", + "0.6.0/tabular_classification_model.pt", marks=pytest.mark.skipif( not _TABULAR_TESTING, reason="tabular packages aren't installed", @@ -282,7 +282,7 @@ def test_task_datapipeline_save(tmpdir): ), pytest.param( TextClassifier, - "0.5.2/text_classification_model.pt", + "0.6.0/text_classification_model.pt", marks=pytest.mark.skipif( not _TEXT_TESTING, reason="text packages aren't installed", @@ -290,7 +290,7 @@ def test_task_datapipeline_save(tmpdir): ), pytest.param( SummarizationTask, - "0.5.2/summarization_model_xsum.pt", + "0.6.0/summarization_model_xsum.pt", marks=pytest.mark.skipif( not _TEXT_TESTING, reason="text packages aren't installed", @@ -298,7 +298,7 @@ def test_task_datapipeline_save(tmpdir): ), pytest.param( TranslationTask, - "0.5.2/translation_model_en_ro.pt", + "0.6.0/translation_model_en_ro.pt", marks=pytest.mark.skipif( not _TEXT_TESTING, reason="text packages aren't installed",