-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Code Normalization #120
Closed
Closed
Code Normalization #120
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from data_juicer.utils.constant import Fields, StatsKeys | ||
from data_juicer.utils.mm_utils import remove_special_tokens | ||
from data_juicer.utils.model_utils import get_model, prepare_model | ||
|
||
from ..base_op import OPERATORS, Filter | ||
|
||
OP_NAME = 'text_action_filter' | ||
|
||
|
||
@OPERATORS.register_module(OP_NAME)
class TextActionFilter(Filter):
    """
    Filter to keep samples whose texts contain at least a minimum
    number of actions (verbs).
    """

    def __init__(self,
                 lang: str = 'en',
                 min_action_num: int = 1,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lang: language of the text in the samples. 'en' for
            detection of actions in English and 'zh' for detection of
            actions in Chinese.
        :param min_action_num: the minimum action number in the
            filtering. Samples will be filtered out if the action number
            in their text is below this parameter.
        :raises ValueError: if ``lang`` is not one of 'en' or 'zh'.
        """
        super().__init__(*args, **kwargs)

        if lang not in ['en', 'zh']:
            raise ValueError(
                f'Language [{lang}] is not supported in action detection. '
                f'Can only be one of ["en", "zh"].')
        self.lang = lang
        self.model_key = prepare_model(model_type='spacy', lang=lang)
        # coarse-grained POS and fine-grained tags that mark verbs/actions
        # ('VV' for Chinese; 'VB*' for the English verb inflections)
        self.action_poss = ['VERB']
        self.action_tags = ['VV', 'VB', 'VBP', 'VBZ', 'VBD', 'VBG', 'VBN']
        self.min_action_num = min_action_num

    def compute_stats(self, sample, context=False):
        """Compute and cache the number of actions in the sample text."""
        # check if it's computed already
        if StatsKeys.num_action in sample[Fields.stats]:
            return sample

        # strip multimodal special tokens before linguistic analysis
        text = remove_special_tokens(sample[self.text_key])

        # process text via spacy and count the actions in text
        model = get_model(self.model_key)
        doc = model(text)
        num_action = sum(1 for token in doc
                         if token.pos_ in self.action_poss
                         and token.tag_ in self.action_tags)
        sample[Fields.stats][StatsKeys.num_action] = num_action

        return sample

    def process(self, sample):
        """Keep the sample only if it contains enough actions."""
        return sample[Fields.stats][
            StatsKeys.num_action] >= self.min_action_num
103 changes: 103 additions & 0 deletions
103
data_juicer/ops/filter/text_entity_dependency_filter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import numpy as np | ||
|
||
from data_juicer.utils.constant import Fields, StatsKeys | ||
from data_juicer.utils.mm_utils import remove_special_tokens | ||
from data_juicer.utils.model_utils import get_model, prepare_model | ||
|
||
from ..base_op import OPERATORS, Filter | ||
|
||
OP_NAME = 'text_entity_dependency_filter' | ||
|
||
|
||
@OPERATORS.register_module(OP_NAME)
class TextEntityDependencyFilter(Filter):
    """
    Identify entities in the text that are independent of any other
    token, and filter the sample accordingly. Texts containing no
    entities are omitted.
    """

    def __init__(self,
                 lang: str = 'en',
                 min_dependency_num: int = 1,
                 any_or_all: str = 'all',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lang: language of the text in the samples. 'en' for
            detection of entities in English and 'zh' for detection of
            entities in Chinese.
        :param min_dependency_num: the minimum edge number in the
            filtering. An entity is considered independent if its number
            of edges in the dependency tree is below this parameter.
        :param any_or_all: keep this sample with 'any' or 'all' strategy.
            'any': keep this sample if any entity is dependent. 'all':
            keep this sample only if all entities are dependent.
        :raises ValueError: if ``lang`` or ``any_or_all`` takes an
            unsupported value.
        """
        super().__init__(*args, **kwargs)

        if lang not in ['en', 'zh']:
            raise ValueError(
                f'Language [{lang}] is not supported in entity dependency '
                f'detection. Can only be one of ["en", "zh"].')
        self.lang = lang
        self.model_key = prepare_model(model_type='spacy', lang=lang)
        # coarse-grained POS and fine-grained tags that mark entities
        # (nouns, proper nouns, pronouns in both English and Chinese)
        self.entity_poss = ['NOUN', 'PROPN', 'PRON']
        self.entity_tags = ['NN', 'NR', 'PN', 'NNS', 'NNP', 'NNPS', 'PRP']
        self.min_dependency_num = min_dependency_num
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')

    def compute_stats(self, sample, context=False):
        """Compute and cache the dependency-edge count of each entity."""
        # check if it's computed already
        if StatsKeys.num_dependency_edges in sample[Fields.stats]:
            return sample

        # strip multimodal special tokens before linguistic analysis
        text = remove_special_tokens(sample[self.text_key])

        # identify entities
        model = get_model(self.model_key)
        doc = model(text)
        entity_to_dependency_nums = {}
        for token in doc:
            if token.pos_ in self.entity_poss \
                    and token.tag_ in self.entity_tags:
                entity_to_dependency_nums[token] = 0

        # count the edges of each entity in dependency tree
        for obj in entity_to_dependency_nums:
            # edge to the entity's own head (unless it is the root)
            if obj.dep_ != 'ROOT':
                entity_to_dependency_nums[obj] += 1
        for token in doc:
            # skip punctuation marks such as ',' and '.'
            if token.pos_ == 'PUNCT':
                continue

            # edge from a dependent token pointing at the entity
            if token.head in entity_to_dependency_nums.keys(
            ) and token.dep_ != 'ROOT':
                entity_to_dependency_nums[token.head] += 1

        sample[Fields.stats][StatsKeys.num_dependency_edges] = list(
            entity_to_dependency_nums.values())

        return sample

    def process(self, sample):
        """Decide whether to keep the sample by its entity edge counts."""
        num_dependency_edges = sample[Fields.stats][
            StatsKeys.num_dependency_edges]
        keep_bools = np.array([
            self.min_dependency_num <= num_edge
            for num_edge in num_dependency_edges
        ])
        # omit the samples without any entity
        if len(keep_bools) == 0:
            return False

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,5 +18,6 @@ plotly | |
python-docx | ||
streamlit | ||
spacy==3.5.0 | ||
spacy-pkuseg==0.0.32 | ||
multiprocess==0.70.12 | ||
dill==0.3.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import unittest | ||
import os | ||
|
||
from datasets import Dataset | ||
|
||
from data_juicer.ops.filter.text_action_filter import TextActionFilter | ||
from data_juicer.utils.constant import Fields | ||
from data_juicer.utils.mm_utils import SpecialTokens | ||
|
||
|
||
class TextActionFilterTest(unittest.TestCase):
    """Unit tests for TextActionFilter on English, Chinese, and
    multimodal (image + text) samples."""

    # directory holding shared test data files, relative to this file
    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
                             'data')

    cat_path = os.path.join(data_path, 'cat.jpg')
    img3_path = os.path.join(data_path, 'img3.jpg')

    def _run_text_action_filter(self, dataset: Dataset, target_list, op,
                                column_names):
        # ensure the stats field exists before computing stats
        if Fields.stats not in dataset.features:
            dataset = dataset.add_column(name=Fields.stats,
                                         column=[{}] * dataset.num_rows)
        dataset = dataset.map(op.compute_stats)
        dataset = dataset.filter(op.process)
        dataset = dataset.select_columns(column_names=column_names)
        res_list = dataset.to_list()
        self.assertEqual(res_list, target_list)

    def test_en_text_case(self):
        # only samples containing an English verb should be kept
        ds_list = [{
            'text': 'Tom is playing piano.'
        }, {
            'text': 'Tom plays piano.'
        }, {
            'text': 'Tom played piano.'
        }, {
            'text': 'I play piano.'
        }, {
            'text': 'to play piano.'
        }, {
            'text': 'Tom 在打篮球'
        }, {
            'text': 'a v s e c s f e f g a a a '
        }, {
            'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
        }, {
            'text': 'that is a green tree'
        }]
        tgt_list = [{
            'text': 'Tom is playing piano.'
        }, {
            'text': 'Tom plays piano.'
        }, {
            'text': 'Tom played piano.'
        }, {
            'text': 'I play piano.'
        }, {
            'text': 'to play piano.'
        }]
        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='en')
        self._run_text_action_filter(dataset, tgt_list, op, ['text'])

    def test_zh_text_case(self):
        # only samples containing a Chinese verb should be kept
        ds_list = [{
            'text': '小明在 弹奏钢琴'
        }, {
            'text': 'Tom is playing 篮球'
        }, {
            'text': '上上下下左左右右'
        }, {
            'text': 'Tom在打篮球'
        }, {
            'text': '我有一只猫,它是一只猫'
        }]
        tgt_list = [{
            'text': '小明在 弹奏钢琴'
        }, {
            'text': 'Tom在打篮球'
        }]
        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='zh')
        self._run_text_action_filter(dataset, tgt_list, op, ['text'])

    def test_image_text_case(self):
        # multimodal special tokens must be stripped before detection
        ds_list = [{
            'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
            'images': [self.cat_path]
        }, {
            'text': f'{SpecialTokens.image}小猫咪',
            'images': [self.cat_path]
        }, {
            'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}',
            'images': [self.img3_path]
        }, {
            'text': '雨中行走的女人背影',
            'images': [self.img3_path]
        }]
        tgt_list = [{
            'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}',
            'images': [self.cat_path]
        }, {
            'text': '雨中行走的女人背影',
            'images': [self.img3_path]
        }]

        dataset = Dataset.from_list(ds_list)
        op = TextActionFilter(lang='zh')
        self._run_text_action_filter(dataset, tgt_list, op, ['text', 'images'])


if __name__ == '__main__':
    unittest.main()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why we need this, does it affect historical codes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If not, yapf would be conflicting with isort in certain case.