diff --git a/CHANGELOG.md b/CHANGELOG.md index 201c120031..5a77d72ec0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed compatibility with `torchmetrics==1.10.0` ([#1469](https://github.com/Lightning-AI/lightning-flash/pull/1469)) + ## [0.8.0] - 2022-09-02 diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index a7d45ad566..9f067fb900 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -90,6 +90,7 @@ class Image: _PANDAS_GREATER_EQUAL_1_3_0 = compare_version("pandas", operator.ge, "1.3.0") _ICEVISION_GREATER_EQUAL_0_11_0 = compare_version("icevision", operator.ge, "0.11.0") _TM_GREATER_EQUAL_0_7_0 = compare_version("torchmetrics", operator.ge, "0.7.0") + _TM_GREATER_EQUAL_0_10_0 = compare_version("torchmetrics", operator.ge, "0.10.0") _BAAL_GREATER_EQUAL_1_5_2 = compare_version("baal", operator.ge, "1.5.2") _TEXT_AVAILABLE = all( diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 4ea7ac14e6..4d7d796c9c 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -25,6 +25,7 @@ from flash.core.serve import Composition from flash.core.utilities.imports import ( _TM_GREATER_EQUAL_0_7_0, + _TM_GREATER_EQUAL_0_10_0, _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_9, requires, @@ -55,7 +56,9 @@ class InterpolationMode: NEAREST = "nearest" -if _TM_GREATER_EQUAL_0_7_0: +if _TM_GREATER_EQUAL_0_10_0: + from torchmetrics.classification import MulticlassJaccardIndex as JaccardIndex +elif _TM_GREATER_EQUAL_0_7_0: from torchmetrics import JaccardIndex else: from torchmetrics import IoU as JaccardIndex diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index f68d412f7c..7fbcbc2460 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -24,7 +24,7 @@ from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.collate import wrap_collate from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _TM_GREATER_EQUAL_0_7_0 +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _TM_GREATER_EQUAL_0_7_0, _TM_GREATER_EQUAL_0_10_0 from flash.core.utilities.stability import beta from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES @@ -33,7 +33,9 @@ from open3d._ml3d.torch.modules.losses.semseg_loss import filter_valid_label from open3d.ml.torch.dataloaders import TorchDataloader -if _TM_GREATER_EQUAL_0_7_0: +if _TM_GREATER_EQUAL_0_10_0: + from torchmetrics.classification import MulticlassJaccardIndex as JaccardIndex +elif _TM_GREATER_EQUAL_0_7_0: from torchmetrics import JaccardIndex else: from torchmetrics import IoU as JaccardIndex diff --git a/flash_examples/serve/semantic_segmentation/inference_server.py b/flash_examples/serve/semantic_segmentation/inference_server.py index d373f722c3..9fbfc21b46 100644 --- a/flash_examples/serve/semantic_segmentation/inference_server.py +++ b/flash_examples/serve/semantic_segmentation/inference_server.py @@ -15,7 +15,7 @@ from flash.image.segmentation.output import SegmentationLabelsOutput model = SemanticSegmentation.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.8.0/semantic_segmentation_model.pt" + "https://flash-weights.s3.amazonaws.com/0.9.0/semantic_segmentation_model.pt" ) model.output = SegmentationLabelsOutput(visualize=False) model.serve() diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 5b701fdf4f..65ec60052b 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -220,7 +220,7 @@ def test_classification_task_trainer_predict(tmpdir): ), pytest.param( SemanticSegmentation, - "0.8.0/semantic_segmentation_model.pt", + "0.9.0/semantic_segmentation_model.pt", marks=pytest.mark.skipif( not _IMAGE_TESTING, reason="image packages aren't installed",