Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when using ClassBalancedDataset #555

Merged
merged 8 commits into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmcls/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def get_cat_ids(self, idx):
idx (int): Index of data.

Returns:
int: Image category of specified index.
cat_ids (np.ndarray): Image category of specified index.
Ezra-Yu marked this conversation as resolved.
Show resolved Hide resolved
"""

return self.data_infos[idx]['gt_label'].astype(np.int)
return self.data_infos[idx]['gt_label'].astype(np.int64)

def prepare_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
Expand Down
4 changes: 2 additions & 2 deletions mmcls/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _get_repeat_factors(self, dataset, repeat_thr):
category_freq = defaultdict(int)
num_images = len(dataset)
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))
cat_ids = set(self.dataset.get_cat_ids(idx).flatten())
for cat_id in cat_ids:
category_freq[cat_id] += 1
for k, v in category_freq.items():
Expand All @@ -156,7 +156,7 @@ def _get_repeat_factors(self, dataset, repeat_thr):
# r(I) = max_{c in L(I)} r(c)
repeat_factors = []
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))
cat_ids = set(self.dataset.get_cat_ids(idx).flatten())
repeat_factor = max(
{category_repeat[cat_id]
for cat_id in cat_ids})
Expand Down
12 changes: 12 additions & 0 deletions mmcls/datasets/imagenet21k.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def __init__(self,
super(ImageNet21k, self).__init__(data_prefix, pipeline, classes,
ann_file, test_mode)

def get_cat_ids(self, idx):
"""Get category id by index.

Args:
idx (int): Index of data.

Returns:
cat_ids (np.ndarray): Image category of specified index.
"""

return np.array(self.data_infos[idx].gt_label, dtype=np.int64)

def prepare_data(self, idx):
info = self.data_infos[idx]
results = {
Expand Down
4 changes: 2 additions & 2 deletions mmcls/datasets/multi_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def get_cat_ids(self, idx):
idx (int): Index of data.

Returns:
np.ndarray: Image categories of specified index.
cat_ids (np.ndarray): Image categories of specified index.
"""
gt_labels = self.data_infos[idx]['gt_label']
cat_ids = np.where(gt_labels == 1)[0]
cat_ids = np.where(gt_labels == 1)[0].astype(np.int64)
return cat_ids

def evaluate(self,
Expand Down
18 changes: 17 additions & 1 deletion tests/test_data/test_datasets/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
def test_datasets_override_default(dataset_name):
dataset_class = DATASETS.get(dataset_name)
load_annotations_f = dataset_class.load_annotations
dataset_class.load_annotations = MagicMock()
ann = [
dict(
img_prefix='',
img_info=dict(),
gt_label=np.array(0, dtype=np.int64))
]
dataset_class.load_annotations = MagicMock(return_value=ann)

original_classes = dataset_class.CLASSES

Expand All @@ -44,6 +50,11 @@ def test_datasets_override_default(dataset_name):
test_mode=True)
assert dataset.CLASSES == ('bus', 'car')

# Test get_cat_ids
if dataset_name != 'ImageNet21k':
assert isinstance(dataset.get_cat_ids(0), np.ndarray)
assert np.issubdtype(dataset.get_cat_ids(0).dtype, np.int64)

# Test setting classes as a list
dataset = dataset_class(
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
Expand Down Expand Up @@ -280,6 +291,11 @@ def test_dataset_imagenet21k():
assert 'img_info' in dataset[0]
assert 'gt_label' in dataset[0]

# Test get_cat_ids
assert isinstance(dataset.get_cat_ids(0),
np.ndarray), type(dataset.get_cat_ids(0))
assert np.issubdtype(dataset.get_cat_ids(0).dtype, np.int64)

# test with recursion_subdir is False
dataset_cfg = base_dataset_cfg.copy()
dataset_cfg['recursion_subdir'] = False
Expand Down
52 changes: 40 additions & 12 deletions tests/test_data/test_datasets/test_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
RepeatDataset)


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_dataset(length):
def construct_toy_multi_label_dataset(length):
BaseDataset.CLASSES = ('foo', 'bar')
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
cat_ids_list = [
np.random.randint(0, 80, num).tolist()
np.random.randint(0, 80, num)
for num in np.random.randint(1, 20, length)
]
dataset.data_infos = MagicMock()
Expand All @@ -25,38 +26,65 @@ def construct_toy_dataset(length):
return dataset, cat_ids_list


def test_concat_dataset():
@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_sigle_label_dataset(length):
BaseDataset.CLASSES = ('foo', 'bar')
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
cat_ids_list = [
np.array(np.random.randint(0, 80), dtype=np.int64)
for _ in range(length)
]
dataset.data_infos = MagicMock()
dataset.data_infos.__len__.return_value = length
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
return dataset, cat_ids_list


@pytest.mark.parametrize(
'construct_dataset',
['construct_toy_multi_label_dataset', 'construct_toy_sigle_label_dataset'])
def test_concat_dataset(construct_dataset):
construct_toy_dataset = eval(construct_dataset)
dataset_a, cat_ids_list_a = construct_toy_dataset(10)
dataset_b, cat_ids_list_b = construct_toy_dataset(20)

concat_dataset = ConcatDataset([dataset_a, dataset_b])
assert concat_dataset[5] == 5
assert concat_dataset[25] == 15
assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
assert np.array_equal(concat_dataset.get_cat_ids(5), cat_ids_list_a[5])
assert np.array_equal(concat_dataset.get_cat_ids(25), cat_ids_list_b[15])
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
assert concat_dataset.CLASSES == BaseDataset.CLASSES


def test_repeat_dataset():
@pytest.mark.parametrize(
'construct_dataset',
['construct_toy_multi_label_dataset', 'construct_toy_sigle_label_dataset'])
def test_repeat_dataset(construct_dataset):
construct_toy_dataset = eval(construct_dataset)
dataset, cat_ids_list = construct_toy_dataset(10)
repeat_dataset = RepeatDataset(dataset, 10)
assert repeat_dataset[5] == 5
assert repeat_dataset[15] == 5
assert repeat_dataset[27] == 7
assert repeat_dataset.get_cat_ids(5) == cat_ids_list[5]
assert repeat_dataset.get_cat_ids(15) == cat_ids_list[5]
assert repeat_dataset.get_cat_ids(27) == cat_ids_list[7]
assert np.array_equal(repeat_dataset.get_cat_ids(5), cat_ids_list[5])
assert np.array_equal(repeat_dataset.get_cat_ids(15), cat_ids_list[5])
assert np.array_equal(repeat_dataset.get_cat_ids(27), cat_ids_list[7])
assert len(repeat_dataset) == 10 * len(dataset)
assert repeat_dataset.CLASSES == BaseDataset.CLASSES


def test_class_balanced_dataset():
@pytest.mark.parametrize(
'construct_dataset',
['construct_toy_multi_label_dataset', 'construct_toy_sigle_label_dataset'])
def test_class_balanced_dataset(construct_dataset):
construct_toy_dataset = eval(construct_dataset)
dataset, cat_ids_list = construct_toy_dataset(10)

category_freq = defaultdict(int)
for cat_ids in cat_ids_list:
cat_ids = set(cat_ids)
cat_ids = set(cat_ids.flatten())
for cat_id in cat_ids:
category_freq[cat_id] += 1
for k, v in category_freq.items():
Expand All @@ -72,7 +100,7 @@ def test_class_balanced_dataset():

repeat_factors = []
for cat_ids in cat_ids_list:
cat_ids = set(cat_ids)
cat_ids = set(cat_ids.flatten())
repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
repeat_factors.append(math.ceil(repeat_factor))
repeat_factors_cumsum = np.cumsum(repeat_factors)
Expand Down