Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/image shape filter #74

Merged
merged 5 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion configs/config_all.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Process config example including:
# - all global arguments
# - all ops and their default arguments
# - all ops and their arguments

# global parameters
project_name: 'all' # project name for distinguish your configs
Expand Down Expand Up @@ -126,6 +126,12 @@ process:
min_ratio: 0.333 # the min aspect ratio of filter range
max_ratio: 3.0 # the max aspect ratio of filter range
any_or_all: any # keep this sample when any/all images meet the filter condition
- image_shape_filter: # filter samples according to the widths and heights of images in them
min_width: 200 # the min width of width filter range
max_width: 5000 # the max width of width filter range
min_height: 200 # the min height of height filter range
max_height: 5000 # the max height of height filter range
any_or_all: any # keep this sample when any/all images meet the filter condition
- image_size_filter: # filter samples according to the size of images (in bytes) within them
min_size: "0" # the min size of filter range
max_size: "1TB" # the max size of filter range
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import (alphanumeric_filter, average_line_length_filter,
character_repetition_filter, clip_similarity_filter,
flagged_words_filter, image_aspect_ratio_filter,
image_size_filter, language_id_score_filter,
image_shape_filter, image_size_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
Expand Down
107 changes: 107 additions & 0 deletions data_juicer/ops/filter/image_shape_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import sys

import numpy as np
from jsonargparse.typing import PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_image

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES


@OPERATORS.register_module('image_shape_filter')
@LOADED_IMAGES.register_module('image_shape_filter')
class ImageShapeFilter(Filter):
"""Filter to keep samples with image shape (w, h) within specific ranges.
"""

def __init__(self,
min_width: PositiveInt = 1,
max_width: PositiveInt = sys.maxsize,
min_height: PositiveInt = 1,
max_height: PositiveInt = sys.maxsize,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.

:param min_width: The min width to keep samples.
:param max_width: The max width to keep samples.
:param min_height: The min height to keep samples.
:param max_height: The max height to keep samples.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_width = min_width
self.max_width = max_width
self.min_height = min_height
self.max_height = max_height
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.image_width in sample[Fields.stats] \
and StatsKeys.image_height in sample[Fields.stats]:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][StatsKeys.image_width] = np.array(
[], dtype=np.int64)
sample[Fields.stats][StatsKeys.image_height] = np.array(
[], dtype=np.int64)
return sample

# load images
loaded_image_keys = sample[self.image_key]
images = {}
for loaded_image_key in loaded_image_keys:
if context and loaded_image_key in sample[Fields.context]:
# load from context
images[loaded_image_key] = sample[
Fields.context][loaded_image_key]
else:
if loaded_image_key not in images:
# avoid load the same images
image = load_image(loaded_image_key)
images[loaded_image_key] = image
if context:
# store the image data into context
sample[Fields.context][loaded_image_key] = image

# get width and height for each image
whs = {key: (images[key].width, images[key].height) for key in images}
sample[Fields.stats][StatsKeys.image_width] = [
whs[key][0] for key in loaded_image_keys
]
sample[Fields.stats][StatsKeys.image_height] = [
whs[key][1] for key in loaded_image_keys
]
return sample

def process(self, sample):
ws = sample[Fields.stats][StatsKeys.image_width]
hs = sample[Fields.stats][StatsKeys.image_height]
if len(ws) <= 0:
return True
keep_bools = np.array([
self.min_width <= w <= self.max_width
and self.min_height <= h <= self.max_height
for w, h in zip(ws, hs)
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
2 changes: 1 addition & 1 deletion data_juicer/utils/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(self, compressor_format: str = 'zstd'):
self.compressor_extension = '.' + compressor_format
self.compress_manager = CompressManager(
compressor_format=compressor_format)
self.pattern = re.compile('_\d{5}_of_') # noqa W605
self.pattern = re.compile(r'_\d{5}_of_')

def _get_raw_filename(self, filename: Union[Path, str]):
"""
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class StatsKeys(object):

# image
aspect_ratios = 'aspect_ratios'
image_width = 'image_width'
image_height = 'image_height'
image_sizes = 'image_sizes'

# multimodal
Expand Down
3 changes: 2 additions & 1 deletion demos/overview_scan/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
|-----------------------------------|:------:|-------------------------------------------------|
| Formatter | 7 | Discovers, loads, and canonicalizes source data |
| Mapper | 21 | Edits and transforms samples |
| Filter | 19 | Filters out low-quality samples |
| Filter | 20 | Filters out low-quality samples |
| Deduplicator | 4 | Detects and removes duplicate samples |
| Selector | 2 | Selects top samples based on ranking |
'''
Expand Down Expand Up @@ -143,6 +143,7 @@
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges |
| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
| [ Filter ]( #filter ) | 19 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 20 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |

Expand Down Expand Up @@ -80,6 +80,7 @@ All the specific operators are listed below, each featured with several capabili
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| image_shape_filter | Image | - | Keeps samples contains images with widths and heights within specific ranges |
| image_size_filter | Image | - | Keeps samples contains images whose size in bytes are within specific range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 19 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 20 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |

Expand Down Expand Up @@ -77,6 +77,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 |
| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
| image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 |
| image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 |
| language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 |
| maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 |
Expand Down
1 change: 1 addition & 0 deletions environments/minimal_requires.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
fsspec==2023.3.0
pyarrow<=13.0.0
pandas==2.0.0
datasets==2.11.0
loguru
Expand Down
127 changes: 127 additions & 0 deletions tests/ops/filter/test_image_shape_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import os
import unittest
import numpy as np
import PIL.Image

from datasets import Dataset, Image

from data_juicer.ops.filter.image_shape_filter import ImageShapeFilter
from data_juicer.utils.constant import Fields


class ImageShapeFilterTest(unittest.TestCase):

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'..', 'data')
img1_path = os.path.join(data_path, 'img1.png')
img2_path = os.path.join(data_path, 'img2.jpg')
img3_path = os.path.join(data_path, 'img3.jpg')

def _run_image_shape_filter(self,
dataset: Dataset,
target_list,
op):
if Fields.stats not in dataset.features:
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)
dataset = dataset.select_columns(column_names=[op.image_key])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

def test_filter1(self):

ds_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
tgt_list = [{
'images': [self.img2_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(min_width=400,
min_height=400)
self._run_image_shape_filter(dataset, tgt_list, op)

def test_filter2(self):

ds_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
tgt_list = [{
'images': [self.img1_path]
}, {
'images': [self.img3_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(max_width=500,
max_height=500)
self._run_image_shape_filter(dataset, tgt_list, op)

def test_filter3(self):

ds_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
tgt_list = [{
'images': [self.img1_path]
}, {
'images': [self.img2_path]
}, {
'images': [self.img3_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter()
self._run_image_shape_filter(dataset, tgt_list, op)

def test_any(self):

ds_list = [{
'images': [self.img1_path, self.img2_path]
}, {
'images': [self.img2_path, self.img3_path]
}, {
'images': [self.img1_path, self.img3_path]
}]
tgt_list = [{
'images': [self.img1_path, self.img2_path]
}, {
'images': [self.img2_path, self.img3_path]
}]
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(min_width=400,
min_height=400,
any_or_all='any')
self._run_image_shape_filter(dataset, tgt_list, op)

def test_all(self):

ds_list = [{
'images': [self.img1_path, self.img2_path]
}, {
'images': [self.img2_path, self.img3_path]
}, {
'images': [self.img1_path, self.img3_path]
}]
tgt_list = []
dataset = Dataset.from_list(ds_list)
op = ImageShapeFilter(min_width=400,
min_height=400,
any_or_all='all')
self._run_image_shape_filter(dataset, tgt_list, op)


if __name__ == '__main__':
unittest.main()