Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/multimodal format & basic processing pipeline for multimodal datasets support #64

Merged
merged 13 commits into from
Nov 13, 2023
Merged
10 changes: 10 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ trace_num: 10 # number of samples
op_fusion: false # whether to fuse operators that share the same intermediate variables automatically. Op fusion might reduce the memory requirements slightly but speed up the whole process.
cache_compress: null # The compression method of the cache file, which can be specified in ['gzip', 'zstd', 'lz4']. If this parameter is None, the cache file will not be compressed. We recommend you turn on this argument when your input dataset is larger than tens of GB and your disk space is not enough.

# for multimodal data processing
image_key: 'images' # Key name of field to store the list of sample image paths.
image_special_token: '<__dj__image>' # The special token that represents an image in the text. In default, it's "<__dj__image>". You can specify your own special token according to your input dataset.

eoc_special_token: '<|__dj__eoc|>' # The special token that represents the end of a chunk in the text. In default, it's "<|__dj__eoc|>". You can specify your own special token according to your input dataset.

# for distributed processing
executor_type: default # Type of executor, support "default" or "ray" for now.
ray_address: auto # The address of the Ray cluster.
Expand Down Expand Up @@ -110,6 +116,10 @@ process:
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
- image_aspect_ratio_filter: # filter samples according to the aspect ratios of images (a fraction of width by height, r=w/h) in them
min_ratio: 0.333 # the min aspect ratio of filter range
max_ratio: 3.0 # the max aspect ratio of filter range
any_or_all: any # keep this sample when any/all images meet the filter condition
- language_id_score_filter: # filter text in specific language with language scores larger than a specific max value
lang: en # keep text in what language
min_score: 0.8 # the min language scores to filter text
Expand Down
43 changes: 42 additions & 1 deletion data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from data_juicer.ops.base_op import OPERATORS
from data_juicer.utils.logger_utils import setup_logger
from data_juicer.utils.mm_utils import SpecialTokens


def init_configs(args=None):
Expand Down Expand Up @@ -102,6 +103,25 @@ def init_configs(args=None):
'requiring multiple keys, users can specify the op multiple '
'times. We will only use the first key of `text_keys` when you '
'set multiple keys.')
parser.add_argument(
'--image_key',
type=str,
default='images',
help='Key name of field to store the list of sample image paths.')
parser.add_argument(
'--image_special_token',
type=str,
default=SpecialTokens.image,
help='The special token that represents an image in the text. In '
'default, it\'s "<__dj__image>". You can specify your own special'
' token according to your input dataset.')
parser.add_argument(
'--eoc_special_token',
type=str,
default=SpecialTokens.eoc,
help='The special token that represents the end of a chunk in the '
'text. In default, it\'s "<|__dj__eoc|>". You can specify your '
'own special token according to your input dataset.')
parser.add_argument(
'--suffixes',
type=Union[str, List[str], Tuple[str]],
Expand Down Expand Up @@ -289,6 +309,19 @@ def init_setup_from_cfg(cfg):
filename=logfile_name,
redirect=cfg.executor_type == 'default')

# check and get dataset dir
if os.path.exists(cfg.dataset_path):
if os.path.isdir(cfg.dataset_path):
cfg.dataset_dir = os.path.abspath(cfg.dataset_path)
else:
cfg.dataset_dir = os.path.abspath(
os.path.dirname(cfg.dataset_path))
else:
logger.error(f'Input dataset_path [{cfg.dataset_path}] is invalid. '
f'Please check and retry.')
raise ValueError(f'Input dataset_path [{cfg.dataset_path}] is '
f'invalid. Please check and retry.')

# whether or not to use cache management
# disabling the cache or using checkpoint explicitly will turn off the
# cache management.
Expand Down Expand Up @@ -334,6 +367,10 @@ def init_setup_from_cfg(cfg):
cfg.add_suffix = True
break

# update special tokens
SpecialTokens.image = cfg.image_special_token
SpecialTokens.eoc = cfg.eoc_special_token

# Apply text_key modification during initializing configs
# users can freely specify text_key for different ops using `text_key`
# otherwise, set arg text_key of each op to text_keys
Expand All @@ -345,9 +382,13 @@ def init_setup_from_cfg(cfg):
for op_name in op:
args = op[op_name]
if args is None:
args = {'text_key': text_key}
args = {
'text_key': text_key,
'image_key': cfg.image_key,
}
elif args['text_key'] is None:
args['text_key'] = text_key
args['image_key'] = cfg.image_key
op[op_name] = args

return cfg
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/core/analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run(self, load_data_np=None):
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np)
dataset = self.formatter.load_dataset(load_data_np, self.cfg)

# extract processes
logger.info('Preparing process operators...')
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def run(self, load_data_np=None):
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np)
dataset = self.formatter.load_dataset(load_data_np, self.cfg)

# 2. extract processes
logger.info('Preparing process operators...')
Expand Down
17 changes: 15 additions & 2 deletions data_juicer/core/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def export(self, dataset):
@staticmethod
def to_jsonl(dataset, export_path, num_proc=1, **kwargs):
"""
Export method for json/jsonl target files.
Export method for jsonl target files.

:param dataset: the dataset to export.
:param export_path: the path to store the exported dataset.
Expand All @@ -186,6 +186,19 @@ def to_jsonl(dataset, export_path, num_proc=1, **kwargs):
"""
dataset.to_json(export_path, force_ascii=False, num_proc=num_proc)

@staticmethod
def to_json(dataset, export_path, num_proc=1, **kwargs):
"""
Export method for json target files.

:param dataset: the dataset to export.
:param export_path: the path to store the exported dataset.
:param num_proc: the number of processes used to export the dataset.
:param kwargs: extra arguments.
:return:
"""
dataset.to_json(export_path, force_ascii=False, num_proc=num_proc, lines=False)

@staticmethod
def to_parquet(dataset, export_path, **kwargs):
"""
Expand All @@ -208,6 +221,6 @@ def _router():
"""
return {
'jsonl': Exporter.to_jsonl,
'json': Exporter.to_jsonl,
'json': Exporter.to_json,
'parquet': Exporter.to_parquet,
}
65 changes: 52 additions & 13 deletions data_juicer/format/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def __init__(
self.data_files = find_files_with_suffix(dataset_path, suffixes)
self.add_suffix = add_suffix

def load_dataset(self, num_proc: int = 1) -> Dataset:
def load_dataset(self,
num_proc: int = 1,
global_cfg=None) -> Dataset:
"""
Load a dataset from dataset file or dataset directory, and unify its
format.
Expand All @@ -76,7 +78,8 @@ def load_dataset(self, num_proc: int = 1) -> Dataset:
concatenate_datasets([ds for _, ds in datasets.items()]))
ds = unify_format(datasets,
text_keys=self.text_keys,
num_proc=num_proc)
num_proc=num_proc,
global_cfg=global_cfg)
return ds


Expand All @@ -100,7 +103,9 @@ def __init__(self,
self.text_keys = text_keys
self.kwargs = kwargs

def load_dataset(self, num_proc: int = 1) -> Dataset:
def load_dataset(self,
num_proc: int = 1,
global_cfg=None) -> Dataset:
"""
Load a dataset from HuggingFace, and unify its format.

Expand All @@ -112,7 +117,10 @@ def load_dataset(self, num_proc: int = 1) -> Dataset:
split='train',
num_proc=num_proc,
**self.kwargs)
ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc)
ds = unify_format(ds,
text_keys=self.text_keys,
num_proc=num_proc,
global_cfg=global_cfg)
return ds


Expand All @@ -137,6 +145,7 @@ def unify_format(
dataset: Dataset,
text_keys: Union[List[str], str] = 'text',
num_proc: int = 1,
global_cfg=None,
) -> Dataset:
"""
Get an unified internal format, conduct the following modifications.
Expand Down Expand Up @@ -201,12 +210,40 @@ def non_empty_text(sample, target_keys):
fn_kwargs={'target_keys': text_keys})
logger.info(f'{len(dataset)} samples left after filtering empty text.')

# 3. add Fields.stats field
# TODO:
# this is a temp solution,
# it will occur errors when only call mapper ops
# dataset = dataset.add_column( \
# name=Fields.stats, column=[{}] * dataset.num_rows)
# 3. convert relative paths to absolute paths
if global_cfg:
logger.info('Converting relative paths in the dataset to their '
'absolute version. (Based on the directory of input '
'dataset file)')
ds_dir = global_cfg.dataset_dir
image_key = global_cfg.image_key

# function to convert relative paths to absolute paths
def rel2abs(sample, path_keys, dataset_dir):
for path_key in path_keys:
if path_key not in sample:
continue
paths = sample[path_key]
if not paths:
continue
new_paths = [os.path.join(dataset_dir, path)
for path in paths if not os.path.isabs(path)]
sample[path_key] = new_paths
return sample

dataset = dataset.map(rel2abs,
num_proc=num_proc,
fn_kwargs={
'path_keys': [
image_key,
],
'dataset_dir': ds_dir
})
else:
logger.warning(f'No global config passed into unify_format function. '
f'Relative paths in the dataset might not be converted '
f'to their absolute versions. Data of other modalities '
f'might not be able to find by Data-Juicer.')

return dataset

Expand Down Expand Up @@ -262,6 +299,8 @@ def load_formatter(dataset_path,

# no data
else:
raise ValueError('Can not found local data or huggingface '
'dataset-hub for your given path: '
f'{dataset_path} and suffixes: {suffixes}')
raise ValueError(f'Unable to load the dataset from [{dataset_path}]. '
f'It might be because Data-Juicer doesn\'t support '
f'the format of this dataset, or the path of this '
f'dataset is incorrect.Please check if it\'s a valid '
f'dataset path and retry.')
5 changes: 3 additions & 2 deletions data_juicer/format/mixture_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,17 @@ def _random_sample(self, dataset, weight=1.0, seed=None):
return dataset
return dataset.shuffle(seed=seed).select(range(num_samples))

def load_dataset(self, num_proc: int = 1) -> Dataset:
def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
"""
Load a mixed dataset.

:param num_proc: number of processes when loading the dataset
:param global_cfg: the global cfg used in consequent processes,
:return: mixed dataset
"""
dataset_list = []
for weight, formatter in zip(self.weights, self.formatters):
dataset = formatter.load_dataset(num_proc)
dataset = formatter.load_dataset(num_proc, global_cfg)
sampled = self._random_sample(dataset, weight)
logger.info(f'sampled {len(sampled)} from '
f'{len(dataset)} with weight {weight}')
Expand Down
8 changes: 6 additions & 2 deletions data_juicer/format/text_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ def __init__(self,
self.dataset_path = dataset_path
self.add_suffix = add_suffix

def load_dataset(self, num_proc: int = 1) -> Dataset:
def load_dataset(self,
num_proc: int = 1,
global_cfg=None) -> Dataset:
"""
Load a dataset from local text-type files.

:param num_proc: number of processes when loading the dataset
:param global_cfg: the global cfg used in consequent processes,
:return: unified_format_dataset.
"""
# extract text to cache directory
Expand Down Expand Up @@ -154,4 +157,5 @@ def load_dataset(self, num_proc: int = 1) -> Dataset:
datasets = concatenate_datasets([ds for _, ds in datasets.items()])
return unify_format(datasets,
text_keys=self.text_keys,
num_proc=num_proc)
num_proc=num_proc,
global_cfg=global_cfg)
Loading