diff --git a/src/otx/algo/common/layers/spp_layer.py b/src/otx/algo/common/layers/spp_layer.py index 3fd8c6e46a6..a2f5803d4ee 100644 --- a/src/otx/algo/common/layers/spp_layer.py +++ b/src/otx/algo/common/layers/spp_layer.py @@ -67,6 +67,6 @@ def __init__( def forward(self, x: Tensor) -> Tensor: """Forward.""" x = self.conv1(x) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(x.device.type, enabled=False): x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1) return self.conv2(x) diff --git a/src/otx/algo/detection/layers/channel_attention_layer.py b/src/otx/algo/detection/layers/channel_attention_layer.py index c8c256b4e6f..a776e615276 100644 --- a/src/otx/algo/detection/layers/channel_attention_layer.py +++ b/src/otx/algo/detection/layers/channel_attention_layer.py @@ -33,7 +33,7 @@ def __init__( def forward(self, x: Tensor) -> Tensor: """Forward function for ChannelAttention.""" - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(x.device.type, enabled=False): out = self.global_avgpool(x) out = self.fc(out) out = self.act(out) diff --git a/src/otx/algo/keypoint_detection/heads/rtmcc_head.py b/src/otx/algo/keypoint_detection/heads/rtmcc_head.py index 63f12d2edca..d6f0f7cd0d5 100644 --- a/src/otx/algo/keypoint_detection/heads/rtmcc_head.py +++ b/src/otx/algo/keypoint_detection/heads/rtmcc_head.py @@ -208,6 +208,8 @@ def to_numpy(self, x: Tensor | tuple[Tensor, Tensor]) -> np.ndarray | tuple[np.n np.ndarray | tuple: return a tuple of converted numpy array(s) """ if isinstance(x, Tensor): + if x.dtype == torch.bfloat16: + x = x.float() return x.detach().cpu().numpy() if isinstance(x, tuple) and all(isinstance(i, Tensor) for i in x): return tuple([self.to_numpy(i) for i in x]) diff --git a/src/otx/core/metrics/pck.py b/src/otx/core/metrics/pck.py index 941ed679775..966d56d841f 100644 --- a/src/otx/core/metrics/pck.py +++ b/src/otx/core/metrics/pck.py @@ -25,7 +25,7 @@ def _calc_distances(preds: np.ndarray, gts: np.ndarray, mask: np.ndarray, norm_f Args: preds (np.ndarray[N, K, D]): Predicted keypoint location. - gts (np.ndarray[N, K, D]): Groundtruth keypoint location. + gts (np.ndarray[N, K, D]): Ground truth keypoint location. mask (np.ndarray[N, K]): Visibility of the target. False for invisible joints, and True for visible. Invisible joints will be ignored for accuracy calculation. @@ -75,7 +75,7 @@ def keypoint_pck_accuracy( pred: np.ndarray, gt: np.ndarray, mask: np.ndarray, - thr: np.ndarray, + thr: float, norm_factor: np.ndarray, ) -> tuple[np.ndarray, float, int]: """Calculate the pose accuracy of PCK for each individual keypoint. @@ -99,7 +99,7 @@ def keypoint_pck_accuracy( joints, and True for visible. Invisible joints will be ignored for accuracy calculation. thr (float): Threshold of PCK calculation. - norm_factor (np.ndarray[N, 2]): Normalization factor for H&W. + norm_factor (np.ndarray[N, 2]): Normalization factor for the keypoints. Returns: tuple: A tuple containing keypoint accuracy. @@ -117,34 +117,22 @@ def keypoint_pck_accuracy( class PCKMeasure(Metric): - """Computes the f-measure (also known as F1-score) for a resultset. - - The f-measure is typically used in detection (localization) tasks to obtain a single number that balances precision - and recall. - - To determine whether a predicted box matches a ground truth box an overlap measured - is used based on a minimum - intersection-over-union (IoU), by default a value of 0.5 is used. - - In addition spurious results are eliminated by applying non-max suppression (NMS) so that two predicted boxes with - IoU > threshold are reduced to one. This threshold can be determined automatically by setting `vary_nms_threshold` - to True. + """Computes the pose accuracy (also known as PCK) for a resultset. Args: label_info (int): Dataclass including label information. - vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different NMS threshold - values. Defaults to False. - cross_class_nms (bool): Whether non-max suppression should be applied cross-class. If True this will eliminate - boxes with sufficient overlap even if they are from different classes. Defaults to False. + dist_threshold (float): Threshold of PCK calculation. """ def __init__( self, label_info: LabelInfo, + dist_threshold: float = 0.05, ): super().__init__() self.label_info: LabelInfo = label_info + self.dist_threshold: float = dist_threshold self.reset() @property @@ -190,19 +178,33 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]] def compute(self) -> dict: """Compute PCK score metric.""" pred_kpts = np.stack([p[0].cpu().numpy() for p in self.preds]) - gt_kpts = np.stack([p[0] for p in self.targets]) - kpts_visible = np.stack([p[1] for p in self.targets]) - - normalize = np.tile(np.array([self.input_size]), (pred_kpts.shape[0], 1)) + gt_kpts_processed = [] + for p in self.targets: + if len(p[0].shape) == 3 and p[0].shape[0] == 1: + gt_kpts_processed.append(p[0].squeeze()) + else: + gt_kpts_processed.append(p[0]) + gt_kpts = np.stack(gt_kpts_processed) + + kpts_visible = [] + for p in self.targets: + if len(p[1].shape) == 3 and p[1].shape[0] == 1: + kpts_visible.append(p[1].squeeze()) + else: + kpts_visible.append(p[1]) + + kpts_visible_stacked = np.stack(kpts_visible) + + normalize = np.tile(np.array([self.input_size[::-1]]), (pred_kpts.shape[0], 1)) _, avg_acc, _ = keypoint_pck_accuracy( pred_kpts, gt_kpts, - mask=kpts_visible > 0, - thr=0.05, + mask=kpts_visible_stacked > 0, + thr=self.dist_threshold, norm_factor=normalize, ) - return {"accuracy": Tensor([avg_acc])} + return {"PCK": Tensor([avg_acc])} def _pck_measure_callable(label_info: LabelInfo) -> PCKMeasure: diff --git a/src/otx/core/model/keypoint_detection.py b/src/otx/core/model/keypoint_detection.py index 5f6377c28fa..ae856d5c702 100644 --- a/src/otx/core/model/keypoint_detection.py +++ b/src/otx/core/model/keypoint_detection.py @@ -15,18 +15,19 @@ from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity from otx.core.metrics import MetricCallable, MetricInput from otx.core.metrics.pck import PCKMeasureCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel +from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.export import TaskLevelExportParameters from otx.core.types.label import LabelInfoTypes if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable + from model_api.models.utils import DetectedKeypoints from torch import nn class OTXKeypointDetectionModel(OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]): - """Base class for the detection models used in OTX.""" + """Base class for the keypoint detection models used in OTX.""" def __init__( self, @@ -37,8 +38,6 @@ def __init__( metric: MetricCallable = PCKMeasureCallable, torch_compile: bool = False, ) -> None: - self.mean = (0.0, 0.0, 0.0) - self.std = (255.0, 255.0, 255.0) super().__init__( label_info=label_info, input_size=input_size, @@ -157,6 +156,97 @@ def _export_parameters(self) -> TaskLevelExportParameters: model_type="keypoint_detection", task_type="keypoint_detection", confidence_threshold=self.hparams.get("best_confidence_threshold", None), - iou_threshold=0.5, - tile_config=self.tile_config if self.tile_config.enable_tiler else None, ) + + def get_dummy_input(self, batch_size: int = 1) -> KeypointDetBatchDataEntity: + """Returns a dummy input for key point detection model.""" + images = torch.rand(batch_size, 3, *self.input_size) + return KeypointDetBatchDataEntity( + batch_size, + images, + [], + [torch.tensor([0, 0, self.input_size[1], self.input_size[0]])], + labels=[], + bbox_info=[], + keypoints=[], + keypoints_visible=[], + ) + + +class OVKeypointDetectionModel(OVModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]): + """Keypoint detection model compatible for OpenVINO IR inference. + + It can consume OpenVINO IR model path or model name from Intel OMZ repository + and create the OTX keypoint detection model compatible for OTX testing pipeline. + """ + + def __init__( + self, + model_name: str, + model_type: str = "keypoint_detection", + async_inference: bool = True, + max_num_requests: int | None = None, + use_throughput_mode: bool = True, + model_api_configuration: dict[str, Any] | None = None, + metric: MetricCallable = PCKMeasureCallable, + **kwargs, + ) -> None: + super().__init__( + model_name=model_name, + model_type=model_type, + async_inference=async_inference, + max_num_requests=max_num_requests, + use_throughput_mode=use_throughput_mode, + model_api_configuration=model_api_configuration, + metric=metric, + ) + + def _customize_outputs( + self, + outputs: list[DetectedKeypoints], + inputs: KeypointDetBatchDataEntity, + ) -> KeypointDetBatchPredEntity | OTXBatchLossEntity: + keypoints = [] + scores = [] + for output in outputs: + keypoints.append(torch.as_tensor(output.keypoints, device=self.device)) + scores.append(torch.as_tensor(output.scores, device=self.device)) + + return KeypointDetBatchPredEntity( + batch_size=len(outputs), + images=inputs.images, + imgs_info=inputs.imgs_info, + keypoints=keypoints, + scores=scores, + keypoints_visible=[], + bboxes=[], + labels=[], + bbox_info=[], + ) + + def configure_metric(self) -> None: + """Configure the metric.""" + super().configure_metric() + self._metric.input_size = (self.model.h, self.model.w) + + def _convert_pred_entity_to_compute_metric( + self, + preds: KeypointDetBatchPredEntity, + inputs: KeypointDetBatchDataEntity, + ) -> MetricInput: + return { + "preds": [ + { + "keypoints": kpt, + "scores": score, + } + for kpt, score in zip(preds.keypoints, preds.scores) + ], + "target": [ + { + "keypoints": kpt, + "keypoints_visible": kpt_visible, + } + for kpt, kpt_visible in zip(inputs.keypoints, inputs.keypoints_visible) + ], + } diff --git a/src/otx/engine/utils/auto_configurator.py b/src/otx/engine/utils/auto_configurator.py index cc9dbdfdd52..16fb530610c 100644 --- a/src/otx/engine/utils/auto_configurator.py +++ b/src/otx/engine/utils/auto_configurator.py @@ -90,6 +90,7 @@ OTXTaskType.ANOMALY_CLASSIFICATION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO", OTXTaskType.ANOMALY_DETECTION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO", OTXTaskType.ANOMALY_SEGMENTATION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO", + OTXTaskType.KEYPOINT_DETECTION: "otx.core.model.keypoint_detection.OVKeypointDetectionModel", } diff --git a/src/otx/recipe/keypoint_detection/openvino_model.yaml b/src/otx/recipe/keypoint_detection/openvino_model.yaml new file mode 100644 index 00000000000..040bf28ba69 --- /dev/null +++ b/src/otx/recipe/keypoint_detection/openvino_model.yaml @@ -0,0 +1,50 @@ +model: + class_path: otx.core.model.keypoint_detection.OVKeypointDetectionModel + init_args: + label_info: 19 + model_name: rtm_pose_tiny + model_type: "keypoint_detection" + async_inference: true + use_throughput_mode: true + +engine: + task: KEYPOINT_DETECTION + device: cpu + +callback_monitor: val/PCK + +data: ../_base_/data/keypoint_detection.yaml +overrides: + reset: + - data.train_subset.transforms + - data.val_subset.transforms + - data.test_subset.transforms + + data: + stack_images: false + train_subset: + batch_size: 1 + num_workers: 2 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale + - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine + init_args: + input_size: $(input_size) + + val_subset: + batch_size: 1 + num_workers: 2 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale + - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine + init_args: + input_size: $(input_size) + + test_subset: + batch_size: 64 + num_workers: 2 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale + - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine + init_args: + input_size: $(input_size) diff --git a/src/otx/recipe/keypoint_detection/rtmpose_tiny.yaml b/src/otx/recipe/keypoint_detection/rtmpose_tiny.yaml index 447d4fd5218..0b63845dab3 100644 --- a/src/otx/recipe/keypoint_detection/rtmpose_tiny.yaml +++ b/src/otx/recipe/keypoint_detection/rtmpose_tiny.yaml @@ -19,13 +19,13 @@ model: mode: max factor: 0.1 patience: 9 - monitor: val/accuracy + monitor: val/PCK engine: task: KEYPOINT_DETECTION device: auto -callback_monitor: val/accuracy +callback_monitor: val/PCK data: ../_base_/data/keypoint_detection.yaml diff --git a/src/otx/recipe/keypoint_detection/rtmpose_tiny_single_obj.yaml b/src/otx/recipe/keypoint_detection/rtmpose_tiny_single_obj.yaml index 8045bb5e85c..82ba01a4c96 100644 --- a/src/otx/recipe/keypoint_detection/rtmpose_tiny_single_obj.yaml +++ b/src/otx/recipe/keypoint_detection/rtmpose_tiny_single_obj.yaml @@ -22,13 +22,13 @@ model: mode: max factor: 0.1 patience: 9 - monitor: val/accuracy + monitor: val/PCK engine: task: KEYPOINT_DETECTION device: auto -callback_monitor: val/accuracy +callback_monitor: val/PCK data: ../_base_/data/keypoint_detection.yaml diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index c7d1b32823f..7c3c9d2a959 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -257,6 +257,9 @@ def test_otx_e2e( if "yolov9" in model_name: return # RT-DETR currently is not supported. + if "keypoint" in recipe: + print("Explain is not supported for keypoint detection") + return tmp_path_test = tmp_path / f"otx_export_xai_{model_name}" for export_case in fxt_export_list: diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 8071925aadb..f39547ca81a 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -48,6 +48,7 @@ def fxt_local_seed() -> int: "visual_prompting": "test/f1-score", "zero_shot_visual_prompting": "test/f1-score", "action_classification": "test/accuracy", + "keypoint_detection": "test/PCK", } diff --git a/tests/unit/core/metrics/pck.py b/tests/unit/core/metrics/pck.py new file mode 100644 index 00000000000..07b7a887261 --- /dev/null +++ b/tests/unit/core/metrics/pck.py @@ -0,0 +1,50 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests for PCK metric.""" + +from __future__ import annotations + +import pytest +import torch +from otx.core.metrics.pck import PCKMeasure +from otx.core.types.label import LabelInfo + + +class TestPCK: + @pytest.fixture() + def fxt_preds(self) -> list[dict[str, torch.Tensor]]: + return [ + { + "keypoints": torch.Tensor([[0.7, 0.6], [0.9, 0.6]]), + "scores": torch.Tensor([0.9, 0.8]), + }, + { + "keypoints": torch.Tensor([[0.3, 0.4], [0.6, 0.6]]), + "scores": torch.Tensor([0.9, 0.8]), + }, + ] + + @pytest.fixture() + def fxt_targets(self) -> list[dict[str, torch.Tensor]]: + return [ + { + "keypoints": torch.Tensor([[0.3, 0.4], [0.6, 0.6]]), + "keypoints_visible": torch.Tensor([0.9, 0.8]), + }, + { + "keypoints": torch.Tensor([[0.7, 0.6], [0.9, 0.6]]), + "keypoints_visible": torch.Tensor([0.9, 0.8]), + }, + ] + + def test_pck(self, fxt_preds, fxt_targets) -> None: + metric = PCKMeasure(label_info=LabelInfo.from_num_classes(1)) + metric.input_size = (1, 1) + metric.update(fxt_preds, fxt_targets) + result = metric.compute() + assert result["PCK"] == 0 + + metric.reset() + assert metric.preds == [] + assert metric.targets == [] diff --git a/tests/unit/core/model/test_keypoint_detection.py b/tests/unit/core/model/test_keypoint_detection.py new file mode 100644 index 00000000000..d3cc06fede7 --- /dev/null +++ b/tests/unit/core/model/test_keypoint_detection.py @@ -0,0 +1,98 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests for keypoint detection model entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import torch +from otx.algo.keypoint_detection.rtmpose import RTMPoseTiny +from otx.core.data.entity.base import OTXBatchLossEntity +from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity +from otx.core.metrics.pck import PCKMeasureCallable +from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable +from otx.core.types.label import LabelInfo + +if TYPE_CHECKING: + from otx.core.model.keypoint_detection import OTXKeypointDetectionModel + + +class TestOTXKeypointDetectionModel: + @pytest.fixture() + def model(self, label_info, optimizer, scheduler, metric, torch_compile) -> OTXKeypointDetectionModel: + return RTMPoseTiny(label_info, (512, 512), optimizer, scheduler, metric, torch_compile) + + @pytest.fixture() + def batch_data_entity(self, model) -> KeypointDetBatchDataEntity: + return model.get_dummy_input(2) + + @pytest.fixture() + def label_info(self) -> LabelInfo: + return LabelInfo( + label_names=["label_0", "label_1"], + label_groups=[["label_0", "label_1"]], + ) + + @pytest.fixture() + def optimizer(self): + return DefaultOptimizerCallable + + @pytest.fixture() + def scheduler(self): + return DefaultSchedulerCallable + + @pytest.fixture() + def metric(self): + return PCKMeasureCallable + + @pytest.fixture() + def torch_compile(self): + return False + + def test_export_parameters(self, model): + params = model._export_parameters + assert params.model_type == "keypoint_detection" + assert params.task_type == "keypoint_detection" + + @pytest.mark.parametrize( + ("label_info", "expected_label_info"), + [ + ( + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), + ), + (LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)), + ], + ) + def test_dispatch_label_info(self, model, label_info, expected_label_info): + result = model._dispatch_label_info(label_info) + assert result == expected_label_info + + def test_init(self, model): + assert model.num_classes == 2 + + def test_customize_inputs(self, model, batch_data_entity): + customized_inputs = model._customize_inputs(batch_data_entity) + assert customized_inputs["inputs"].shape == (2, 3, model.input_size[0], model.input_size[1]) + assert "mode" in customized_inputs + + def test_customize_outputs_training(self, model, batch_data_entity): + outputs = {"loss": torch.tensor(0.5)} + customized_outputs = model._customize_outputs(outputs, batch_data_entity) + assert isinstance(customized_outputs, OTXBatchLossEntity) + assert customized_outputs["loss"] == torch.tensor(0.5) + + def test_customize_outputs_predict(self, model, batch_data_entity): + model.training = False + outputs = [(torch.randn(2, 2, 2), torch.randn(2, 2, 2))] + customized_outputs = model._customize_outputs(outputs, batch_data_entity) + assert isinstance(customized_outputs, KeypointDetBatchPredEntity) + assert len(customized_outputs.keypoints) == len(customized_outputs.scores) + + def test_dummy_input(self, model: OTXKeypointDetectionModel): + batch_size = 2 + batch = model.get_dummy_input(batch_size) + assert batch.batch_size == batch_size