Skip to content

Commit

Permalink
modify the hook
Browse files Browse the repository at this point in the history
  • Loading branch information
Ezra-Yu committed Nov 2, 2022
1 parent efe2f1a commit 937122a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
4 changes: 2 additions & 2 deletions mmcls/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
from .precise_bn_hook import PreciseBNHook
from .retriever_hooks import ResetPrototypeInitFlagHook
from .retriever_hooks import PrepareProtoBeforeValLoopHook
from .visualization_hook import VisualizationHook

__all__ = [
'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
'ResetPrototypeInitFlagHook'
'PrepareProtoBeforeValLoopHook'
]
16 changes: 8 additions & 8 deletions mmcls/engine/hooks/retriever_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@


@HOOKS.register_module()
class ResetPrototypeInitFlagHook(Hook):
"""The hook to reset the prototype's initialization flag in retrievers.
class PrepareProtoBeforeValLoopHook(Hook):
"""The hook to prepare the prototype in retrievers.
Since the encoders of the retriever changes during training, the prototype
changes accordingly. So the `prototype_inited` needs to be set to False
before validation.
changes accordingly. So the `prototype_vecs` needs to be regenerated before
validation loop.
"""

def before_val(self, runner) -> None:
if isinstance(runner.model, BaseRetriever):
if hasattr(runner.model, 'prototype_inited'):
runner.model.prototype_inited = False
if hasattr(runner.model, 'prepare_prototype'):
runner.model.prepare_prototype()
else:
warnings.warn(
'Only the retriever can execute `ResetPrototypeInitFlagHook`,'
f'but got {type(runner.model)}')
'Only the retrievers can execute PrepareRetrieverPrototypeHook'
f', but got {type(runner.model)}')
13 changes: 7 additions & 6 deletions tests/test_engine/test_hooks/test_retrievers_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,31 @@

import torch

from mmcls.engine import ResetPrototypeInitFlagHook
from mmcls.engine import PrepareProtoBeforeValLoopHook
from mmcls.models.retrievers import BaseRetriever


class ToyRetriever(BaseRetriever):

def forward(self, inputs, data_samples=None, mode: str = 'loss'):
pass
self.prototype_inited is False

def prepare_prototype(self):
"""Preprocessing the prototype before predict."""
self.prototype_vecs = torch.tensor([0])
self.prototype_inited = True


class TestClassNumCheckHook(TestCase):
class TestPrepareProtBeforeValLoopHook(TestCase):

def setUp(self):
self.hook = ResetPrototypeInitFlagHook()
self.hook = PrepareProtoBeforeValLoopHook
self.runner = MagicMock()
self.runner.model = ToyRetriever()

def test_before_val(self):
self.runner.model.prepare_prototype()
self.assertTrue(self.runner.model.prototype_inited)
self.hook.before_val(self.runner)
self.assertFalse(self.runner.model.prototype_inited)
self.hook.before_val(self, self.runner)
self.assertIsNotNone(self.runner.model.prototype_vecs)
self.assertTrue(self.runner.model.prototype_inited)

0 comments on commit 937122a

Please sign in to comment.