Skip to content

Commit

Permalink
Use register_buffer to save prototype vectors and add a progress bar
Browse files Browse the repository at this point in the history
during preparing prototype.
  • Loading branch information
mzr1996 committed Oct 13, 2022
1 parent 5641075 commit b8094b4
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
1 change: 0 additions & 1 deletion mmcls/models/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
self.prototype = prototype
self.prototype_inited = False
self.prototype_vecs = None

@abstractmethod
def forward(self,
Expand Down
13 changes: 7 additions & 6 deletions mmcls/models/retrievers/image2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import track_on_main_process
from .base import BaseRetriever


Expand Down Expand Up @@ -86,7 +87,6 @@ def __init__(self,

self.prototype = prototype
self.prototype_inited = False
self.prototype_vecs = None
self.topk = topk

@property
Expand Down Expand Up @@ -248,7 +248,7 @@ def _get_prototype_vecs_from_dataloader(self):
num = len(data_loader.dataset)

prototype_vecs = None
for data_batch in data_loader:
for data_batch in track_on_main_process(data_loader):
data = self.data_preprocessor(data_batch, False)
feat = self(**data)
if isinstance(feat, tuple):
Expand Down Expand Up @@ -278,16 +278,17 @@ def prepare_prototype(self):
"""
device = next(self.image_encoder.parameters()).device
if isinstance(self.prototype, torch.Tensor):
self.prototype_vecs = self.prototype
prototype_vecs = self.prototype
elif isinstance(self.prototype, str):
self.prototype_vecs = torch.load(self.prototype)
prototype_vecs = torch.load(self.prototype)
elif isinstance(self.prototype, dict):
self.prototype = Runner.build_dataloader(self.prototype)

if isinstance(self.prototype, DataLoader):
self.prototype_vecs = self._get_prototype_vecs_from_dataloader()
prototype_vecs = self._get_prototype_vecs_from_dataloader()

self.prototype_vecs = self.prototype_vecs.to(device)
self.register_buffer(
'prototype_vecs', prototype_vecs.to(device), persistent=False)
self.prototype_inited = True

def dump_prototype(self, path):
Expand Down
3 changes: 2 additions & 1 deletion mmcls/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .progress import track_on_main_process
from .setup_env import register_all_modules

__all__ = ['collect_env', 'register_all_modules']
__all__ = ['collect_env', 'register_all_modules', 'track_on_main_process']
10 changes: 10 additions & 0 deletions mmcls/utils/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.dist as dist
import rich.progress as progress


def track_on_main_process(sequence, *args, **kwargs):
if not dist.is_main_process():
return sequence

yield from progress.track(sequence, *args, **kwargs)

0 comments on commit b8094b4

Please sign in to comment.