Code Normalization #120

Closed · wants to merge 5 commits
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -11,6 +11,7 @@ repos:
rev: v0.32.0
hooks:
- id: yapf
args: ['--style', '{column_limit: 79}']
Collaborator


Why do we need this? Does it affect historical code?

Collaborator Author


If not, yapf would conflict with isort in certain cases.
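
A hedged illustration of the conflict (not from the PR itself): isort wraps long import lines at 79 columns, so if yapf runs with a different column limit, the two hooks keep re-wrapping each other's output and never converge.

# isort (79-column default) wraps a long import like this:
from data_juicer.ops.filter import (language_id_score_filter,
                                    maximum_line_length_filter)
# yapf with a wider column_limit may join these lines back together,
# so each hook undoes the other; pinning '--style {column_limit: 79}'
# keeps yapf's wrapping consistent with isort's.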

exclude: data_juicer/ops/common/special_characters.py
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
3 changes: 2 additions & 1 deletion data_juicer/ops/filter/__init__.py
@@ -6,5 +6,6 @@
language_id_score_filter, maximum_line_length_filter,
perplexity_filter, special_characters_filter,
specified_field_filter, specified_numeric_field_filter,
stopwords_filter, suffix_filter, text_length_filter,
stopwords_filter, suffix_filter, text_action_filter,
text_entity_dependency_filter, text_length_filter,
token_num_filter, word_num_filter, word_repetition_filter)
66 changes: 66 additions & 0 deletions data_juicer/ops/filter/text_action_filter.py
@@ -0,0 +1,66 @@
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_action_filter'


@OPERATORS.register_module(OP_NAME)
class TextActionFilter(Filter):
"""
Filter to keep texts that contain actions.
"""

def __init__(self,
lang: str = 'en',
min_action_num: int = 1,
*args,
**kwargs):
"""
Initialization method.

:param lang: language of the text in the samples. 'en' for detecting
actions in English and 'zh' for detecting actions in Chinese.
:param min_action_num: the minimum number of actions. Samples will be
filtered out if the number of actions in their text is below this
value.
"""
super().__init__(*args, **kwargs)

if lang not in ['en', 'zh']:
raise ValueError(
f'Language [{lang}] is not supported in action detection. '
f'Can only be one of ["en", "zh"].')
self.lang = lang
self.model_key = prepare_model(model_type='spacy', lang=lang)
# coarse POS tags and fine-grained tags that mark actions (verbs)
self.action_poss = ['VERB']
self.action_tags = ['VV', 'VB', 'VBP', 'VBZ', 'VBD', 'VBG', 'VBN']
self.min_action_num = min_action_num

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.num_action in sample[Fields.stats]:
return sample

text = remove_special_tokens(sample[self.text_key])

# process text via spacy and count the actions in text
model = get_model(self.model_key)
doc = model(text)
num_action = 0
for token in doc:
if token.pos_ in self.action_poss \
and token.tag_ in self.action_tags:
num_action += 1
sample[Fields.stats][StatsKeys.num_action] = num_action

return sample

def process(self, sample):
num_action = sample[Fields.stats][StatsKeys.num_action]
return num_action >= self.min_action_num
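
A minimal usage sketch (not part of the diff), assuming the base Filter defaults text_key to 'text' and an English spaCy model is available via prepare_model:

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields

op = TextActionFilter(lang='en', min_action_num=1)
sample = {'text': 'Tom is playing piano.', Fields.stats: {}}
sample = op.compute_stats(sample)
# 'playing' is tagged VERB/VBG, so one action is counted
print(sample[Fields.stats])  # {'num_action': 1}
print(op.process(sample))    # True: the sample is kept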
103 changes: 103 additions & 0 deletions data_juicer/ops/filter/text_entity_dependency_filter.py
@@ -0,0 +1,103 @@
import numpy as np

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

OP_NAME = 'text_entity_dependency_filter'


@OPERATORS.register_module(OP_NAME)
class TextEntityDependencyFilter(Filter):
"""
Identify entities in the text that are independent of any other token,
and filter the sample accordingly. Texts containing no entities will be
omitted.
"""

def __init__(self,
lang: str = 'en',
min_dependency_num: int = 1,
any_or_all: str = 'all',
*args,
**kwargs):
"""
Initialization method.

:param lang: language of the text in the samples. 'en' for detecting
entities in English and 'zh' for detecting entities in Chinese.
:param min_dependency_num: the minimum number of dependency edges. An
entity is considered independent if its number of edges in the
dependency tree is below this value.
:param any_or_all: keep this sample with 'any' or 'all' strategy.
'any': keep this sample if any entity is dependent. 'all': keep this
sample only if all entities are dependent.
"""
super().__init__(*args, **kwargs)

if lang not in ['en', 'zh']:
raise ValueError(
f'Language [{lang}] is not supported in entity dependency '
f'detection. Can only be one of ["en", "zh"].')
self.lang = lang
self.model_key = prepare_model(model_type='spacy', lang=lang)
# coarse POS tags and fine-grained tags that mark entities
self.entity_poss = ['NOUN', 'PROPN', 'PRON']
self.entity_tags = ['NN', 'NR', 'PN', 'NNS', 'NNP', 'NNPS', 'PRP']
self.min_dependency_num = min_dependency_num
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.num_dependency_edges in sample[Fields.stats]:
return sample

text = remove_special_tokens(sample[self.text_key])

# identify entities
model = get_model(self.model_key)
doc = model(text)
entity_to_dependency_nums = {}
for token in doc:
if token.pos_ in self.entity_poss \
and token.tag_ in self.entity_tags:
entity_to_dependency_nums[token] = 0

# count the edges of each entity in dependency tree
for obj in entity_to_dependency_nums:
if obj.dep_ != 'ROOT':
entity_to_dependency_nums[obj] += 1
for token in doc:
# skip punctuation marks such as ',' and '.'
if token.pos_ == 'PUNCT':
continue

if token.head in entity_to_dependency_nums.keys(
) and token.dep_ != 'ROOT':
entity_to_dependency_nums[token.head] += 1

sample[Fields.stats][StatsKeys.num_dependency_edges] = list(
entity_to_dependency_nums.values())

return sample

def process(self, sample):
num_dependency_edges = sample[Fields.stats][
StatsKeys.num_dependency_edges]
keep_bools = np.array([
self.min_dependency_num <= num_edge
for num_edge in num_dependency_edges
])
# omit the samples without entity
if len(keep_bools) == 0:
return False

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
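
A hedged sketch of the edge counting above, using spaCy directly (the model name is an assumption for illustration; the op itself loads models through prepare_model):

import spacy

nlp = spacy.load('en_core_web_sm')  # assumed model
doc = nlp('Tom plays piano.')
entities = {t: 0 for t in doc if t.pos_ in ('NOUN', 'PROPN', 'PRON')}
for t in entities:  # every non-ROOT entity has one edge to its head
    if t.dep_ != 'ROOT':
        entities[t] += 1
for t in doc:  # non-punctuation children of an entity add more edges
    if t.pos_ != 'PUNCT' and t.head in entities and t.dep_ != 'ROOT':
        entities[t.head] += 1
print({t.text: n for t, n in entities.items()})  # e.g. {'Tom': 1, 'piano': 1}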
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
@@ -22,6 +22,8 @@ class StatsKeys(object):
special_char_ratio = 'special_char_ratio'
stopwords_ratio = 'stopwords_ratio'
text_len = 'text_len'
num_action = 'num_action'
num_dependency_edges = 'num_dependency_edges'
num_token = 'num_token'
num_words = 'num_words'
word_rep_ratio = 'word_rep_ratio'
1 change: 1 addition & 0 deletions environments/minimal_requires.txt
@@ -18,5 +18,6 @@ plotly
python-docx
streamlit
spacy==3.5.0
spacy-pkuseg==0.0.32
multiprocess==0.70.12
dill==0.3.4
114 changes: 114 additions & 0 deletions tests/ops/filter/test_text_action_filter.py
@@ -0,0 +1,114 @@
import os
import unittest

from datasets import Dataset

from data_juicer.ops.filter.text_action_filter import TextActionFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens


class TextActionFilterTest(unittest.TestCase):

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
'data')

cat_path = os.path.join(data_path, 'cat.jpg')
img3_path = os.path.join(data_path, 'img3.jpg')

def _run_text_action_filter(self, dataset: Dataset, target_list, op,
column_names):
if Fields.stats not in dataset.features:
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)
dataset = dataset.select_columns(column_names=column_names)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)

def test_en_text_case(self):

ds_list = [{
'text': 'Tom is playing piano.'
}, {
'text': 'Tom plays piano.'
}, {
'text': 'Tom played piano.'
}, {
'text': 'I play piano.'
}, {
'text': 'to play piano.'
}, {
'text': 'Tom 在打篮球'
}, {
'text': 'a v s e c s f e f g a a a '
}, {
'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
}, {
'text': 'that is a green tree'
}]
tgt_list = [{
'text': 'Tom is playing piano.'
}, {
'text': 'Tom plays piano.'
}, {
'text': 'Tom played piano.'
}, {
'text': 'I play piano.'
}, {
'text': 'to play piano.'
}]
dataset = Dataset.from_list(ds_list)
op = TextActionFilter(lang='en')
self._run_text_action_filter(dataset, tgt_list, op, ['text'])

def test_zh_text_case(self):

ds_list = [{
'text': '小明在 弹奏钢琴'
}, {
'text': 'Tom is playing 篮球'
}, {
'text': '上上下下左左右右'
}, {
'text': 'Tom在打篮球'
}, {
'text': '我有一只猫,它是一只猫'
}]
tgt_list = [{
'text': '小明在 弹奏钢琴'
}, {
'text': 'Tom在打篮球'
}]
dataset = Dataset.from_list(ds_list)
op = TextActionFilter(lang='zh')
self._run_text_action_filter(dataset, tgt_list, op, ['text'])

def test_image_text_case(self):
ds_list = [{
'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
'images': [self.cat_path]
}, {
'text': f'{SpecialTokens.image}小猫咪',
'images': [self.cat_path]
}, {
'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}',
'images': [self.img3_path]
}, {
'text': '雨中行走的女人背影',
'images': [self.img3_path]
}]
tgt_list = [{
'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
'images': [self.cat_path]
}, {
'text': '雨中行走的女人背影',
'images': [self.img3_path]
}]

dataset = Dataset.from_list(ds_list)
op = TextActionFilter(lang='zh')
self._run_text_action_filter(dataset, tgt_list, op, ['text', 'images'])

if __name__ == '__main__':
unittest.main()
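
Since the file calls unittest.main(), the tests can presumably be run directly once the package is importable, e.g.:

python tests/ops/filter/test_text_action_filter.py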