Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Ezra-Yu committed Oct 10, 2022
1 parent d0b43fd commit 6a69ef2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions mmcls/models/retrievers/image2image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from typing import Callable, List, Optional, Union

import mmengine.dist as dist
Expand Down Expand Up @@ -244,7 +243,7 @@ def _get_predictions(self, result, data_samples):
return data_samples

def _get_prototype_vecs_from_dataloader(self):
"""get prototype_vecs from dataloader"""
"""get prototype_vecs from dataloader."""
data_loader = self.prototype
num = len(data_loader.dataset)

Expand All @@ -261,9 +260,9 @@ def _get_prototype_vecs_from_dataloader(self):
for i, data_sample in enumerate(data_batch['data_samples']):
sample_idx = data_sample.get('sample_idx')
prototype_vecs[sample_idx] = feat[i]

assert prototype_vecs is not None
dist.all_reduce(prototype_vecs)
dist.all_reduce(prototype_vecs)
return prototype_vecs

@torch.no_grad()
Expand All @@ -287,7 +286,7 @@ def prepare_prototype(self):

if isinstance(self.prototype, DataLoader):
self.prototype_vecs = self._get_prototype_vecs_from_dataloader()

self.prototype_vecs = self.prototype_vecs.to(device)
self.prototype_inited = True

Expand Down

0 comments on commit 6a69ef2

Please sign in to comment.