Skip to content

Commit

Permalink
Feature/mixture (#86)
Browse files Browse the repository at this point in the history
* feature: add max_samples to limit mixed datasets
  • Loading branch information
zhijianma authored Nov 21, 2023
1 parent 62c5fb5 commit 413be3b
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 13 deletions.
60 changes: 50 additions & 10 deletions data_juicer/format/mixture_formatter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import chain, repeat
from typing import List, Tuple, Union

import numpy as np
Expand All @@ -17,6 +18,7 @@ def __init__(self,
suffixes: Union[str, List[str], Tuple[str]] = None,
text_keys=None,
add_suffix=False,
max_samples=None,
**kwargs):
"""
Initialization method.
Expand All @@ -28,9 +30,30 @@ def __init__(self,
:param text_keys: key names of field that stores sample text.
:param add_suffix: whether to add the file suffix to dataset
meta info
:param max_samples: max samples number of mixed dataset.
:param kwargs: extra args
"""

data_prefixes, weights = self._get_weight(data_prefix=dataset_path)
sample_numbers = [0] * len(weights)
if max_samples is not None:
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
sample_num_per_dataset = [
int(np.ceil(max_samples * weight)) for weight in weights
]

# Adjust
acc_sample_numbers = 0
for i in range(len(sample_num_per_dataset)):
sample_numbers[i] = min(sample_num_per_dataset[i],
max_samples - acc_sample_numbers)
acc_sample_numbers += sample_numbers[i]

self.sample_numbers = sample_numbers
self.weights = weights
self.formatters = [
load_formatter(dataset_path=data_prefix,
Expand All @@ -54,7 +77,7 @@ def _get_weight(self, data_prefix):

for i in range(len(data_prefix)):
try:
value = float(data_prefix[i])
value = max(float(data_prefix[i]), 0.0)
weights.append(value)
except: # noqa: E722
value = data_prefix[i].strip()
Expand All @@ -65,21 +88,36 @@ def _get_weight(self, data_prefix):
prefixes.append(value)
return prefixes, weights

def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None):
    """
    Randomly sample a subset from a dataset with weight or number.

    If sample number is bigger than 0, it is used instead of weight;
    otherwise ``ceil(dataset.num_rows * weight)`` samples are drawn.
    When more samples are requested than the dataset holds, the
    shuffled dataset is repeated as many whole times as needed and
    topped up with a leading slice of itself.

    :param dataset: a HuggingFace dataset
    :param weight: sample ratio of dataset
    :param sample_number: sample number of dataset
    :param seed: random sample seed, if None, 42 as default
    :return: a subset of dataset
    """
    if seed is None:
        seed = 42

    ds_samples = dataset.num_rows
    if ds_samples == 0:
        # Nothing to draw from; also avoids dividing by zero below when
        # an explicit sample number is requested from an empty dataset.
        return dataset

    # NOTE(review): a quota of exactly 0 is indistinguishable from
    # "no quota given" here, so it falls back to weight-based sampling.
    if sample_number <= 0:
        sample_number = int(np.ceil(ds_samples * weight))

    if sample_number == ds_samples:
        # The whole dataset is requested: return it as is, unshuffled.
        return dataset

    # Exact integer arithmetic instead of float ceil-division.
    n_repeat, remain_samples = divmod(sample_number, ds_samples)
    if n_repeat == 0:
        sample_index = range(sample_number)
    else:
        # Whole passes over the dataset followed by the remainder.
        sample_index = chain(
            chain.from_iterable(repeat(range(ds_samples), n_repeat)),
            range(remain_samples))

    return dataset.shuffle(seed=seed).select(sample_index)

def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
"""
Expand All @@ -90,11 +128,13 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
:return: mixed dataset
"""
dataset_list = []
for weight, formatter in zip(self.weights, self.formatters):
for weight, sample_num, formatter in zip(self.weights,
self.sample_numbers,
self.formatters):
dataset = formatter.load_dataset(num_proc, global_cfg)
sampled = self._random_sample(dataset, weight)
sampled = self._random_sample(dataset, weight, sample_num)
logger.info(f'sampled {len(sampled)} from '
f'{len(dataset)} with weight {weight}')
f'{len(dataset)}')
dataset_list.append(sampled)

from data_juicer.core.data import NestedDataset
Expand Down
4 changes: 4 additions & 0 deletions tests/format/data/structured/demo-dataset.jsonl
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}}
{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}}
{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}}
{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}}
{"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}}
{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}}
78 changes: 78 additions & 0 deletions tests/format/test_mixture_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import unittest

from data_juicer.format.mixture_formatter import MixtureFormatter


class MixtureFormatterTest(unittest.TestCase):
    """Tests for weight- and max_samples-based mixing in MixtureFormatter.

    Every case loads the same 6-row demo jsonl file, once or twice, and
    checks the size and the feature names of the mixed dataset.
    """

    def setUp(self):
        self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                  'data', 'structured')
        self._file = os.path.join(self._path, 'demo-dataset.jsonl')
        # The second dataset deliberately points at the same file.
        self._file2 = self._file

    def test_only_file(self):
        """A single path without weight yields the whole file."""
        formatter = MixtureFormatter(self._file)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), 6)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_sample_weight(self):
        """A weight of 0.5 yields half of the rows."""
        formatter = MixtureFormatter('0.5 ' + self._file)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), 3)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_sample_number(self):
        """max_samples alone limits the size of the result."""
        max_samples = 2
        formatter = MixtureFormatter(self._file, max_samples=max_samples)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), max_samples)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_sample_number_weight(self):
        """max_samples takes precedence over the weight of one dataset."""
        max_samples = 2
        formatter = MixtureFormatter('0.5 ' + self._file,
                                     max_samples=max_samples)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), max_samples)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_multi_datasets_without_weight(self):
        """Two unweighted datasets are concatenated in full."""
        data_path = self._file + ' ' + self._file2
        formatter = MixtureFormatter(data_path)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), 12)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_multi_datasets_with_unit_weight(self):
        """Explicit weights of 1.0 behave like omitted weights."""
        # This case used to share its name with the 0.5/0.5 case below,
        # so it was shadowed and never executed; it now has its own name
        # and actually passes weights.
        data_path = '1.0 ' + self._file + ' 1.0 ' + self._file2
        formatter = MixtureFormatter(data_path)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), 12)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_multi_datasets_with_one_weight(self):
        """Only the first dataset is down-weighted: 3 + 6 rows."""
        data_path = '0.5 ' + self._file + ' ' + self._file2
        formatter = MixtureFormatter(data_path)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), 9)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_multi_datasets_with_weight(self):
        """Both datasets are down-weighted to half: 3 + 3 rows."""
        data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2
        formatter = MixtureFormatter(data_path)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), 6)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

    def test_multi_datasets_with_sample(self):
        """max_samples larger than one dataset still caps the mixture."""
        max_samples = 7
        data_path = '0.5 ' + self._file + ' 0.5 ' + self._file2
        formatter = MixtureFormatter(data_path, max_samples=max_samples)
        ds = formatter.load_dataset()
        self.assertEqual(len(ds), max_samples)
        self.assertEqual(list(ds.features.keys()), ['text', 'meta'])

if __name__ == '__main__':
unittest.main()
3 changes: 1 addition & 2 deletions tests/format/test_unify_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,7 @@ def test_hetero_meta(self):
'author': 'xxx'
}
}]
unified_sample_list = ds.to_list()
self.assertEqual(unified_sample_list, sample)

# test nested and missing field for the following cases:
# 1. first row, then column
unified_sample_first = ds[0]
Expand Down
7 changes: 6 additions & 1 deletion tools/postprocess/data_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def parse_args():
'size of each shard won\'t larger than the '
'export_shard_size')

parser.add_argument('--max_samples',
type=int,
default=None,
help='Number of samples of mixed dataset.')

parser.add_argument('--num_proc',
type=int,
default=4,
Expand All @@ -58,7 +63,7 @@ def run_mixture():
"""
args = parse_args()
data_path = ' '.join(args.data_path)
formatter = load_formatter(data_path)
formatter = load_formatter(data_path, max_samples=args.max_samples)
dataset = formatter.load_dataset(args.num_proc)
exporter = Exporter(export_path=args.export_path,
export_shard_size=args.export_shard_size,
Expand Down

0 comments on commit 413be3b

Please sign in to comment.