From 937122a94950e623100f1612ee2ea603ac9f2475 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Wed, 2 Nov 2022 16:21:41 +0800 Subject: [PATCH] modify the hook --- mmcls/engine/hooks/__init__.py | 4 ++-- mmcls/engine/hooks/retriever_hooks.py | 16 ++++++++-------- .../test_hooks/test_retrievers_hooks.py | 13 +++++++------ 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/mmcls/engine/hooks/__init__.py b/mmcls/engine/hooks/__init__.py index 3be4a897774..eb95bb395fc 100644 --- a/mmcls/engine/hooks/__init__.py +++ b/mmcls/engine/hooks/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .class_num_check_hook import ClassNumCheckHook from .precise_bn_hook import PreciseBNHook -from .retriever_hooks import ResetPrototypeInitFlagHook +from .retriever_hooks import PrepareProtoBeforeValLoopHook from .visualization_hook import VisualizationHook __all__ = [ 'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook', - 'ResetPrototypeInitFlagHook' + 'PrepareProtoBeforeValLoopHook' ] diff --git a/mmcls/engine/hooks/retriever_hooks.py b/mmcls/engine/hooks/retriever_hooks.py index d7676746ec5..ed9b6f99434 100644 --- a/mmcls/engine/hooks/retriever_hooks.py +++ b/mmcls/engine/hooks/retriever_hooks.py @@ -8,19 +8,19 @@ @HOOKS.register_module() -class ResetPrototypeInitFlagHook(Hook): - """The hook to reset the prototype's initialization flag in retrievers. +class PrepareProtoBeforeValLoopHook(Hook): + """The hook to prepare the prototype in retrievers. Since the encoders of the retriever changes during training, the prototype - changes accordingly. So the `prototype_inited` needs to be set to False - before validation. + changes accordingly. So the `prototype_vecs` needs to be regenerated before + validation loop. """ def before_val(self, runner) -> None: if isinstance(runner.model, BaseRetriever): - if hasattr(runner.model, 'prototype_inited'): - runner.model.prototype_inited = False + if hasattr(runner.model, 'prepare_prototype'): + runner.model.prepare_prototype() else: warnings.warn( - 'Only the retriever can execute `ResetPrototypeInitFlagHook`,' - f'but got {type(runner.model)}') + 'Only the retrievers can execute PrepareRetrieverPrototypeHook' + f', but got {type(runner.model)}') diff --git a/tests/test_engine/test_hooks/test_retrievers_hooks.py b/tests/test_engine/test_hooks/test_retrievers_hooks.py index 8b6082d78c2..055803f10ac 100644 --- a/tests/test_engine/test_hooks/test_retrievers_hooks.py +++ b/tests/test_engine/test_hooks/test_retrievers_hooks.py @@ -4,14 +4,14 @@ import torch -from mmcls.engine import ResetPrototypeInitFlagHook +from mmcls.engine import PrepareProtoBeforeValLoopHook from mmcls.models.retrievers import BaseRetriever class ToyRetriever(BaseRetriever): def forward(self, inputs, data_samples=None, mode: str = 'loss'): - pass + self.prototype_inited is False def prepare_prototype(self): """Preprocessing the prototype before predict.""" @@ -19,15 +19,16 @@ def prepare_prototype(self): self.prototype_inited = True -class TestClassNumCheckHook(TestCase): +class TestPrepareProtBeforeValLoopHook(TestCase): def setUp(self): - self.hook = ResetPrototypeInitFlagHook() + self.hook = PrepareProtoBeforeValLoopHook self.runner = MagicMock() self.runner.model = ToyRetriever() def test_before_val(self): self.runner.model.prepare_prototype() self.assertTrue(self.runner.model.prototype_inited) - self.hook.before_val(self.runner) - self.assertFalse(self.runner.model.prototype_inited) + self.hook.before_val(self, self.runner) + self.assertIsNotNone(self.runner.model.prototype_vecs) + self.assertTrue(self.runner.model.prototype_inited)