From 81f55331591e88a444c18c80e1660cb90584e071 Mon Sep 17 00:00:00 2001 From: zhijianma Date: Fri, 17 Nov 2023 08:59:42 +0800 Subject: [PATCH 1/5] feature: add max_samples to limit mixed datasets --- data_juicer/format/mixture_formatter.py | 59 +++++++++++++++++++++---- tools/postprocess/data_mixture.py | 7 ++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index f55907f90..d8cc6b9ad 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -17,6 +17,7 @@ def __init__(self, suffixes: Union[str, List[str], Tuple[str]] = None, text_keys=None, add_suffix=False, + max_samples=None, **kwargs): """ Initialization method. @@ -28,9 +29,30 @@ def __init__(self, :param text_keys: key names of field that stores sample text. :param add_suffix: whether to add the file suffix to dataset meta info + :param max_samples: max samples number of mixed dataset. :param kwargs: extra args """ + data_prefixes, weights = self._get_weight(data_prefix=dataset_path) + sample_numbers = [0] * len(weights) + if max_samples is not None: + # Normalize weights. 
+ weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + sample_num_per_dataset = [ + int(np.ceil(max_samples * weight)) for weight in weights + ] + + # Adjust + acc_sample_numbers = 0 + for i in range(len(sample_num_per_dataset)): + sample_numbers[i] = min(sample_num_per_dataset[i], + max_samples - acc_sample_numbers) + acc_sample_numbers += sample_numbers[i] + + self.sample_numbers = sample_numbers self.weights = weights self.formatters = [ load_formatter(dataset_path=data_prefix, @@ -65,21 +87,38 @@ def _get_weight(self, data_prefix): prefixes.append(value) return prefixes, weights - def _random_sample(self, dataset, weight=1.0, seed=None): + def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None): """ - Randomly sample a subset from a dataset with weight. + Randomly sample a subset from a dataset with weight or number, + if sample number is bigger than 0, we will use sample + number instead of weight. 
:param dataset: a HuggingFace dataset :param weight: sample ratio of dataset + :param sample_number: sample number of dataset :param seed: random sample seed, if None, 42 as default :return: a subset of dataset """ if seed is None: seed = 42 - num_samples = min(int(np.ceil(dataset.num_rows * weight)), - dataset.num_rows) - if num_samples == dataset.num_rows: + + ds_samples = dataset.num_rows + if sample_number <= 0: + sample_number = int(np.ceil(ds_samples * weight)) + + if sample_number == ds_samples: return dataset - return dataset.shuffle(seed=seed).select(range(num_samples)) + + num_epochs = int(np.ceil(sample_number / ds_samples)) - 1 + + if num_epochs > 0: + remain_samples = sample_number - num_epochs * ds_samples + sample_index = list(range(ds_samples)) * num_epochs + list( + range(remain_samples)) + else: + remain_samples = sample_number + sample_index = list(range(remain_samples)) + + return dataset.shuffle(seed=seed).select(sample_index) def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: """ @@ -90,11 +129,13 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: :return: mixed dataset """ dataset_list = [] - for weight, formatter in zip(self.weights, self.formatters): + for weight, sample_num, formatter in zip(self.weights, + self.sample_numbers, + self.formatters): dataset = formatter.load_dataset(num_proc, global_cfg) - sampled = self._random_sample(dataset, weight) + sampled = self._random_sample(dataset, weight, sample_num) logger.info(f'sampled {len(sampled)} from ' - f'{len(dataset)} with weight {weight}') + f'{len(dataset)}') dataset_list.append(sampled) from data_juicer.core.data import NestedDataset diff --git a/tools/postprocess/data_mixture.py b/tools/postprocess/data_mixture.py index 146986976..db89a2a1f 100644 --- a/tools/postprocess/data_mixture.py +++ b/tools/postprocess/data_mixture.py @@ -33,6 +33,11 @@ def parse_args(): 'size of each shard won\'t larger than the ' 'export_shard_size') + 
parser.add_argument('--max_samples', + type=int, + default=None, + help='Number of samples of mixed dataset.') + parser.add_argument('--num_proc', type=int, default=4, @@ -58,7 +63,7 @@ def run_mixture(): """ args = parse_args() data_path = ' '.join(args.data_path) - formatter = load_formatter(data_path) + formatter = load_formatter(data_path, max_samples=args.max_samples) dataset = formatter.load_dataset(args.num_proc) exporter = Exporter(export_path=args.export_path, export_shard_size=args.export_shard_size, From 017c36dd5d6eb9a65cd5b5d3935af160e9a442df Mon Sep 17 00:00:00 2001 From: zhijianma Date: Fri, 17 Nov 2023 09:57:03 +0800 Subject: [PATCH 2/5] test: add mixture formatter test cases --- tests/format/test_mixture_formatter.py | 78 ++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/format/test_mixture_formatter.py diff --git a/tests/format/test_mixture_formatter.py b/tests/format/test_mixture_formatter.py new file mode 100644 index 000000000..fc16dcbe1 --- /dev/null +++ b/tests/format/test_mixture_formatter.py @@ -0,0 +1,78 @@ +import os +import unittest + +from data_juicer.format.mixture_formatter import MixtureFormatter + + +class MixtureFormatterTest(unittest.TestCase): + + def setUp(self): + self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'data', 'structured') + self._file = os.path.join(self._path, 'demo-dataset.jsonl') + self._file2 = self._file + + def test_only_file(self): + formatter = MixtureFormatter(self._file) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_sample_weight(self): + formatter = MixtureFormatter('0.5 ' + self._file) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 3) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_sample_number(self): + max_samples = 2 + formatter = MixtureFormatter(self._file, max_samples=max_samples) + ds = 
formatter.load_dataset() + self.assertEqual(len(ds), max_samples) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_sample_number_weight(self): + max_samples = 2 + formatter = MixtureFormatter('0.5 ' + self._file, max_samples=max_samples) + ds = formatter.load_dataset() + self.assertEqual(len(ds), max_samples) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_multi_datasets_without_weight(self): + data_path = self._file + ' ' + self._file2 + formatter = MixtureFormatter(data_path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 12) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_multi_datasets_with_weight(self): + data_path = self._file + ' ' + self._file2 + formatter = MixtureFormatter(data_path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 12) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_multi_datasets_with_one_weight(self): + data_path = '0.5 ' + self._file + ' ' + self._file2 + formatter = MixtureFormatter(data_path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 9) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_multi_datasets_with_weight(self): + data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2 + formatter = MixtureFormatter(data_path) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 6) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + + def test_multi_datasets_with_sample(self): + max_samples = 7 + data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2 + formatter = MixtureFormatter(data_path, max_samples=max_samples) + ds = formatter.load_dataset() + self.assertEqual(len(ds), max_samples) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + +if __name__ == '__main__': + unittest.main() From c2370c94b79b1702599fcde85bd3a59529cb952b Mon Sep 17 00:00:00 2001 From: zhijianma Date: Fri, 17 Nov 2023 10:07:52 +0800 Subject: [PATCH 3/5] test: 
add test data --- tests/format/data/structured/demo-dataset.jsonl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/format/data/structured/demo-dataset.jsonl b/tests/format/data/structured/demo-dataset.jsonl index 707f802b0..590c029f5 100644 --- a/tests/format/data/structured/demo-dataset.jsonl +++ b/tests/format/data/structured/demo-dataset.jsonl @@ -1,2 +1,6 @@ -{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} -{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code"}} +{"text": "你好,请问你是谁", "meta": {"src": "customized"}} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément.", "meta": {"src": "Oscar"}} +{"text": "欢迎来到阿里巴巴!", "meta": {"src": "customized"}} +{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized"}} \ No newline at end of file From 961917661cdfa219b5e2d8e7f6e2a74d3281db24 Mon Sep 17 00:00:00 2001 From: zhijianma Date: Fri, 17 Nov 2023 10:38:07 +0800 Subject: [PATCH 4/5] test: change test data --- tests/format/data/structured/demo-dataset.jsonl | 12 ++++++------ tests/format/test_unify_format.py | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/format/data/structured/demo-dataset.jsonl b/tests/format/data/structured/demo-dataset.jsonl index 590c029f5..77a0a1d88 100644 --- a/tests/format/data/structured/demo-dataset.jsonl +++ b/tests/format/data/structured/demo-dataset.jsonl @@ -1,6 +1,6 @@ -{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv"}} -{"text": "Do you need a cup of coffee?", "meta": {"src": "code"}} -{"text": "你好,请问你是谁", "meta": {"src": "customized"}} -{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités 
sont conçues simultanément.", "meta": {"src": "Oscar"}} -{"text": "欢迎来到阿里巴巴!", "meta": {"src": "customized"}} -{"text": "This paper proposed a novel method on LLM pretraining.", "meta": {"src": "customized"}} \ No newline at end of file +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} +{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} \ No newline at end of file diff --git a/tests/format/test_unify_format.py b/tests/format/test_unify_format.py index 2f64d0dcf..c9b41d19d 100644 --- a/tests/format/test_unify_format.py +++ b/tests/format/test_unify_format.py @@ -366,8 +366,7 @@ def test_hetero_meta(self): 'author': 'xxx' } }] - unified_sample_list = ds.to_list() - self.assertEqual(unified_sample_list, sample) + # test nested and missing field for the following cases: # 1. 
first row, then column unified_sample_first = ds[0] From 942b6f56c68c1fd69e8837bf73d0c565c47c229c Mon Sep 17 00:00:00 2001 From: zhijianma Date: Fri, 17 Nov 2023 16:57:16 +0800 Subject: [PATCH 5/5] use iter instead of list --- data_juicer/format/mixture_formatter.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index d8cc6b9ad..fc7762c23 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -1,3 +1,4 @@ +from itertools import chain, repeat from typing import List, Tuple, Union import numpy as np @@ -76,7 +77,7 @@ def _get_weight(self, data_prefix): for i in range(len(data_prefix)): try: - value = float(data_prefix[i]) + value = max(float(data_prefix[i]), 0.0) weights.append(value) except: # noqa: E722 value = data_prefix[i].strip() @@ -108,15 +109,13 @@ def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None): if sample_number == ds_samples: return dataset - num_epochs = int(np.ceil(sample_number / ds_samples)) - 1 + sample_index = range(sample_number) - if num_epochs > 0: - remain_samples = sample_number - num_epochs * ds_samples - sample_index = list(range(ds_samples)) * num_epochs + list( - range(remain_samples)) - else: - remain_samples = sample_number - sample_index = list(range(remain_samples)) + n_repeat = int(np.ceil(sample_number / ds_samples)) - 1 + if n_repeat > 0: + remain_samples = sample_number - n_repeat * ds_samples + sample_index = chain(*repeat(range(ds_samples), n_repeat), + range(remain_samples)) return dataset.shuffle(seed=seed).select(sample_index)