From b8094b440024e05047ebb01115c9ccee03f2b495 Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Thu, 13 Oct 2022 17:13:37 +0800 Subject: [PATCH] Use `register_buffer` to save prototype vectors and add a progress bar during preparing prototype. --- mmcls/models/retrievers/base.py | 1 - mmcls/models/retrievers/image2image.py | 13 +++++++------ mmcls/utils/__init__.py | 3 ++- mmcls/utils/progress.py | 10 ++++++++++ 4 files changed, 19 insertions(+), 8 deletions(-) create mode 100644 mmcls/utils/progress.py diff --git a/mmcls/models/retrievers/base.py b/mmcls/models/retrievers/base.py index 76d4dff1580..dd5e561a073 100644 --- a/mmcls/models/retrievers/base.py +++ b/mmcls/models/retrievers/base.py @@ -50,7 +50,6 @@ def __init__( init_cfg=init_cfg, data_preprocessor=data_preprocessor) self.prototype = prototype self.prototype_inited = False - self.prototype_vecs = None @abstractmethod def forward(self, diff --git a/mmcls/models/retrievers/image2image.py b/mmcls/models/retrievers/image2image.py index 21091b1ccc4..6e1b5979c96 100644 --- a/mmcls/models/retrievers/image2image.py +++ b/mmcls/models/retrievers/image2image.py @@ -8,6 +8,7 @@ from mmcls.registry import MODELS from mmcls.structures import ClsDataSample +from mmcls.utils import track_on_main_process from .base import BaseRetriever @@ -86,7 +87,6 @@ def __init__(self, self.prototype = prototype self.prototype_inited = False - self.prototype_vecs = None self.topk = topk @property @@ -248,7 +248,7 @@ def _get_prototype_vecs_from_dataloader(self): num = len(data_loader.dataset) prototype_vecs = None - for data_batch in data_loader: + for data_batch in track_on_main_process(data_loader): data = self.data_preprocessor(data_batch, False) feat = self(**data) if isinstance(feat, tuple): @@ -278,16 +278,17 @@ def prepare_prototype(self): """ device = next(self.image_encoder.parameters()).device if isinstance(self.prototype, torch.Tensor): - self.prototype_vecs = self.prototype + prototype_vecs = self.prototype elif isinstance(self.prototype, str): - self.prototype_vecs = torch.load(self.prototype) + prototype_vecs = torch.load(self.prototype) elif isinstance(self.prototype, dict): self.prototype = Runner.build_dataloader(self.prototype) if isinstance(self.prototype, DataLoader): - self.prototype_vecs = self._get_prototype_vecs_from_dataloader() + prototype_vecs = self._get_prototype_vecs_from_dataloader() - self.prototype_vecs = self.prototype_vecs.to(device) + self.register_buffer( + 'prototype_vecs', prototype_vecs.to(device), persistent=False) self.prototype_inited = True def dump_prototype(self, path): diff --git a/mmcls/utils/__init__.py b/mmcls/utils/__init__.py index 04d609cad63..236d91e9f06 100644 --- a/mmcls/utils/__init__.py +++ b/mmcls/utils/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .collect_env import collect_env +from .progress import track_on_main_process from .setup_env import register_all_modules -__all__ = ['collect_env', 'register_all_modules'] +__all__ = ['collect_env', 'register_all_modules', 'track_on_main_process'] diff --git a/mmcls/utils/progress.py b/mmcls/utils/progress.py new file mode 100644 index 00000000000..c200944f7d2 --- /dev/null +++ b/mmcls/utils/progress.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine.dist as dist +import rich.progress as progress + + +def track_on_main_process(sequence, *args, **kwargs): + if not dist.is_main_process(): + return sequence + + yield from progress.track(sequence, *args, **kwargs)