From baddb1442ce51e529a4533405653186c7a373149 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Wed, 15 Nov 2023 18:03:13 +0800 Subject: [PATCH 1/2] * fix bug in nlpaug_en_mapper: nlpaug could generate an indefinite number of augmented samples --- data_juicer/ops/mapper/nlpaug_en_mapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index 6a5148c7b..8509c1ba0 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -125,7 +125,7 @@ def process(self, samples): if key == self.text_key: res_samples[self.text_key] += aug_texts else: - res_samples[key] += res_samples[key] * self.aug_num + res_samples[key] += res_samples[key] * len(aug_texts) else: # apply each aug method to generate several augmented texts for aug_method in self.aug: @@ -134,6 +134,6 @@ def process(self, samples): # add other replicate fields for key in res_samples: if key != self.text_key: - res_samples[key] += res_samples[key] * self.aug_num \ - * len(self.aug) + res_samples[key] = res_samples[key] * \ + len(res_samples[self.text_key]) return res_samples From 35c5a8c1fe4e7070be5a41e5cf340475fb1d93b9 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Wed, 15 Nov 2023 19:37:25 +0800 Subject: [PATCH 2/2] * reconstruct the code structure of two aug mappers --- data_juicer/ops/mapper/nlpaug_en_mapper.py | 24 ++++++++++------------ data_juicer/ops/mapper/nlpcda_zh_mapper.py | 24 ++++++++++------------ 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index 8509c1ba0..ae40b461c 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -118,22 +118,20 @@ def process(self, samples): texts_to_aug = samples[self.text_key][0] # batch_size = 1 res_samples = deepcopy(samples) + # get augmented texts if self.sequential: aug_texts = self.aug.augment(texts_to_aug, n=self.aug_num) - # add augmented samples to the batch with other replicate fields - for key in res_samples: - if key == self.text_key: - res_samples[self.text_key] += aug_texts - else: - res_samples[key] += res_samples[key] * len(aug_texts) else: # apply each aug method to generate several augmented texts + aug_texts = [] for aug_method in self.aug: - aug_texts = aug_method.augment(texts_to_aug, n=self.aug_num) - res_samples[self.text_key] += aug_texts - # add other replicate fields - for key in res_samples: - if key != self.text_key: - res_samples[key] = res_samples[key] * \ - len(res_samples[self.text_key]) + aug_texts += aug_method.augment(texts_to_aug, n=self.aug_num) + + # add augmented samples to the batch with other replicate fields + res_samples[self.text_key] += aug_texts + # add other replicate fields + for key in res_samples: + if key != self.text_key: + res_samples[key] = res_samples[key] * \ + len(res_samples[self.text_key]) return res_samples diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py index 51cf50e49..3f10b2f58 100644 --- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py +++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py @@ -125,6 +125,7 @@ def process(self, samples): texts_to_aug = samples[self.text_key] res_samples = deepcopy(samples) + # get augmented texts if self.sequential: aug_texts = texts_to_aug for aug_method in self.aug_pipeline: @@ -136,20 +137,17 @@ def process(self, samples): aug_texts = results[:] if len(aug_texts) == 1 and aug_texts[0] == texts_to_aug[0]: aug_texts = [] - # add augmented samples to the batch with other replicate fields - for key in res_samples: - if key == self.text_key: - res_samples[self.text_key] += aug_texts - else: - res_samples[key] += res_samples[key] * len(aug_texts) else: # apply each aug method to generate several augmented texts + aug_texts = [] for aug_method in self.aug_pipeline: - aug_texts = aug_method.replace(texts_to_aug[0])[1:] - res_samples[self.text_key] += aug_texts - # add other replicate fields - for key in res_samples: - if key != self.text_key: - res_samples[key] = res_samples[key] * \ - len(res_samples[self.text_key]) + aug_texts += aug_method.replace(texts_to_aug[0])[1:] + + # add augmented samples to the batch with other replicate fields + res_samples[self.text_key] += aug_texts + # add other replicate fields + for key in res_samples: + if key != self.text_key: + res_samples[key] = res_samples[key] * \ + len(res_samples[self.text_key]) return res_samples