diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index 643db46045..168f13417e 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -105,8 +105,7 @@ def test_dataset_wrapper(): img_scale = (60, 60) pipeline = [ - # dict(type='Mosaic', img_scale=img_scale, pad_val=255), - # need to merge mosaic + dict(type='RandomMosaic', prob=1, img_scale=img_scale), dict(type='RandomFlip', prob=0.5), dict(type='Resize', img_scale=img_scale, keep_ratio=False), ] @@ -130,14 +129,8 @@ def test_dataset_wrapper(): classes=classes, palette=palette) len_a = 2 - cat_ids_list_a = [ - np.random.randint(0, 80, num).tolist() - for num in np.random.randint(1, 20, len_a) - ] - dataset_a.data_infos = MagicMock() - dataset_a.data_infos.__len__.return_value = len_a - dataset_a.get_cat_ids = MagicMock( - side_effect=lambda idx: cat_ids_list_a[idx]) + dataset_a.img_infos = MagicMock() + dataset_a.img_infos.__len__.return_value = len_a multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline) assert len(multi_image_mix_dataset) == len(dataset_a)