Distributed data processing with Ray (#21)
* add ray executor
pan-x-c authored Sep 18, 2023
1 parent 0eb6cb0 commit 9a55b8f
Showing 19 changed files with 232 additions and 25 deletions.
15 changes: 14 additions & 1 deletion data_juicer/config/config.py
@@ -44,6 +44,13 @@ def init_configs(args=None):
type=str,
default='hello_world',
help='Name of your data process project.')
parser.add_argument(
'--executor_type',
type=str,
default='default',
choices=['default', 'ray'],
help='Type of executor. Only "default" and "ray" are supported for now.'
)
parser.add_argument(
'--dataset_path',
type=str,
@@ -178,6 +185,12 @@ def init_configs(args=None):
default=False,
help='Whether to save all stats to only one file. Only used in '
'Analysis.')
parser.add_argument(
'--ray_address',
type=str,
default='auto',
help='The address of the Ray cluster.'
)

# add all parameters of the registered ops class to the parser,
# and these op parameters can be modified through the command line,
@@ -271,7 +284,7 @@ def init_setup_from_cfg(cfg):
timestamp = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
cfg.timestamp = timestamp
logfile_name = timestamp + '.txt'
setup_logger(save_dir=log_dir, filename=logfile_name)
setup_logger(save_dir=log_dir, filename=logfile_name, redirect=cfg.executor_type=='default')

# whether or not to use cache management
# disabling the cache or using checkpoint explicitly will turn off the
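With these options in place, switching to the distributed executor amounts to setting executor_type to ray and, optionally, pointing ray_address at a running cluster. A minimal sketch of driving this from Python, assuming init_configs accepts a CLI-style argument list and that a --config option and a configs/demo.yaml file exist (both are illustrative, not part of this commit):

# Sketch only: the config file path and --config flag are assumptions.
from data_juicer.config import init_configs

cfg = init_configs(args=[
    '--config', 'configs/demo.yaml',  # hypothetical config file
    '--executor_type', 'ray',         # use the new Ray-based executor
    '--ray_address', 'auto',          # let Ray discover or start a local cluster
])
print(cfg.executor_type, cfg.ray_address)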
1 change: 1 addition & 0 deletions data_juicer/core/__init__.py
@@ -1,5 +1,6 @@
from .analyser import Analyser
from .data import NestedDataset
from .executor import Executor
from .ray_executor import RayExecutor
from .exporter import Exporter
from .tracer import Tracer
86 changes: 86 additions & 0 deletions data_juicer/core/ray_executor.py
@@ -0,0 +1,86 @@
import os

from loguru import logger
from data_juicer.config import init_configs
from data_juicer.ops import (Filter, Mapper, load_ops)
from data_juicer.utils.constant import Fields

import ray
import ray.data as rd


class RayExecutor:
"""
Executor based on Ray [Experimental].
Run Data-Juicer data processing in a distributed cluster.
1. Only Filter and Mapper operators are supported for now.
2. Only `.json` files can be loaded.
3. Advanced features such as checkpoint and tracer are not supported.
"""

def __init__(self, cfg=None):
"""
Initialization method.
:param cfg: optional config dict.
"""
self.cfg = init_configs() if cfg is None else cfg

self.work_dir = self.cfg.work_dir

self.ops = None
# init ray
logger.info('Initializing Ray ...')
ray.init(self.cfg.ray_address)
self.process_list = self.cfg.process


def run(self, load_data_np=None):
"""
Running the dataset process pipeline.
:param load_data_np: number of workers when loading the dataset.
:return: processed dataset.
"""
# 1. load data
logger.info('Loading dataset with Ray...')
dataset = rd.read_json(self.cfg.dataset_path)

# 2. extract processes
logger.info('Preparing process operators...')
self.process_list, self.ops = load_ops(self.cfg.process,
self.cfg.op_fusion)

# 3. data process
# (tracer and checkpoint are not supported by the Ray executor yet)
if Fields.stats not in dataset.columns(fetch_if_missing=False):
logger.info(f'columns {dataset.columns(fetch_if_missing=False)}')
dataset = dataset.add_column(Fields.stats, lambda df: [{}] * len(df))
logger.info('Processing data...')
for op_cfg, op in zip(self.process_list, self.ops):
op_name, _ = list(op_cfg.items())[0]
try:
if isinstance(op, Mapper):
dataset = dataset.map(op.process)
elif isinstance(op, Filter):
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)
else:
logger.error('Ray executor only supports Filter and Mapper OPs for now')
raise NotImplementedError
except: # noqa: E722
logger.error(f'An error occurred during Op [{op_name}].')
import traceback
traceback.print_exc()
exit(1)

# clean up cache files and record processed ops
logger.info(f'Op [{op_name}] Done. Left '
f'{dataset.count()} samples.')

# 4. data export
logger.info('Exporting dataset to disk...')
dataset.write_json(self.cfg.export_path, force_ascii=False)
return dataset
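Since RayExecutor is now exported from data_juicer.core (see the __init__.py change above), a distributed run can be driven programmatically. A minimal usage sketch, assuming dataset_path points at a .json file and export_path is set in the parsed config:

from data_juicer.config import init_configs
from data_juicer.core import RayExecutor

cfg = init_configs()          # picks up --executor_type, --ray_address, etc.
executor = RayExecutor(cfg)   # connects to the cluster at cfg.ray_address
processed = executor.run()    # returns the processed Ray dataset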
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/alphanumeric_filter.py
@@ -3,7 +3,7 @@
from jsonargparse.typing import PositiveFloat

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Filter
from ..common import get_words_from_document
@@ -54,7 +54,7 @@ def compute_stats(self, sample):
alpha_count = sum(
map(lambda char: 1
if char.isalpha() else 0, sample[self.text_key]))
tokenizer = MODEL_ZOO.get(self.model_key, None)
tokenizer = get_model(self.model_key, model_type='huggingface')
token_count = len(
get_words_from_document(
sample[self.text_key],
6 changes: 4 additions & 2 deletions data_juicer/ops/filter/flagged_words_filter.py
@@ -5,7 +5,7 @@
from jsonargparse.typing import ClosedUnitInterval, List

from data_juicer.utils.constant import Fields, StatsKeys, InterVars
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ...utils.asset_utils import ASSET_DIR, load_words_asset
from ..base_op import OPERATORS, Filter
@@ -56,6 +56,7 @@ def __init__(self,
self.words_aug_group_sizes = words_aug_group_sizes
self.words_aug_join_char = words_aug_join_char
self.model_key = None
self.lang = lang

self.FLAGGED_WORDS = load_words_asset(words_dir=flagged_words_dir,
words_type='flagged_words')
@@ -78,7 +79,8 @@ def compute_stats(self, sample, context=False):
if context and words_key in sample[Fields.context]:
words = sample[Fields.context][words_key]
else:
tokenizer = MODEL_ZOO.get(self.model_key, None)
tokenizer = get_model(self.model_key, lang=self.lang,
model_type='sentencepiece')
words = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/language_id_score_filter.py
@@ -2,7 +2,7 @@
from loguru import logger

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Filter

@@ -38,7 +38,7 @@ def compute_stats(self, sample):
return sample

text = sample[self.text_key].lower().replace('\n', ' ')
ft_model = MODEL_ZOO.get(self.model_key, None)
ft_model = get_model(self.model_key, lang=self.lang, model_type='fasttext')
if ft_model is None:
err_msg = 'Model not loaded. Please retry later.'
logger.error(err_msg)
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/perplexity_filter.py
@@ -5,7 +5,7 @@
from jsonargparse.typing import PositiveFloat

from data_juicer.utils.constant import Fields, StatsKeys, InterVars
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_WORDS
@@ -34,6 +34,7 @@ def __init__(self,
"""
super().__init__(*args, **kwargs)
self.max_ppl = max_ppl
self.lang = lang
self.sp_model_key = prepare_model(lang=lang,
model_type='sentencepiece')
self.kl_model_key = prepare_model(lang=lang, model_type='kenlm')
@@ -48,7 +49,7 @@ def compute_stats(self, sample, context=False):
if context and words_key in sample[Fields.context]:
words = sample[Fields.context][words_key]
else:
tokenizer = MODEL_ZOO.get(self.sp_model_key, None)
tokenizer = get_model(self.sp_model_key, self.lang, 'sentencepiece')
words = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
@@ -57,7 +58,7 @@ def compute_stats(self, sample, context=False):
text = ' '.join(words)
# compute perplexity
logits, length = 0, 0
kenlm_model = MODEL_ZOO.get(self.kl_model_key, None)
kenlm_model = get_model(self.kl_model_key, self.lang, 'kenlm')
for line in text.splitlines():
logits += kenlm_model.score(line)
length += (len(line.split()) + 1)
6 changes: 4 additions & 2 deletions data_juicer/ops/filter/stopwords_filter.py
@@ -6,7 +6,7 @@

from data_juicer.utils.asset_utils import ASSET_DIR, load_words_asset
from data_juicer.utils.constant import Fields, StatsKeys, InterVars
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_WORDS
@@ -55,6 +55,7 @@ def __init__(self,
self.words_aug_group_sizes = words_aug_group_sizes
self.words_aug_join_char = words_aug_join_char
self.model_key = None
self.lang = lang

self.STOPWORDS = load_words_asset(words_dir=stopwords_dir,
words_type='stopwords')
@@ -76,7 +77,8 @@ def compute_stats(self, sample, context=False):
if context and words_key in sample[Fields.context]:
words = sample[Fields.context][words_key]
else:
tokenizer = MODEL_ZOO.get(self.model_key, None)
tokenizer = get_model(self.model_key, lang=self.lang,
model_type='sentencepiece')
words = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
6 changes: 4 additions & 2 deletions data_juicer/ops/filter/word_num_filter.py
@@ -3,7 +3,7 @@
from jsonargparse.typing import PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys, InterVars
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_WORDS
@@ -42,6 +42,7 @@ def __init__(self,
self.min_num = min_num
self.max_num = max_num
self.model_key = None
self.lang = lang

if tokenization:
self.model_key = prepare_model(lang=lang,
@@ -56,7 +57,8 @@ def compute_stats(self, sample, context=False):
if context and words_key in sample[Fields.context]:
words = sample[Fields.context][words_key]
else:
tokenizer = MODEL_ZOO.get(self.model_key, None)
tokenizer = get_model(self.model_key, lang=self.lang,
model_type='sentencepiece')
words = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
6 changes: 4 additions & 2 deletions data_juicer/ops/filter/word_repetition_filter.py
@@ -5,7 +5,7 @@
from jsonargparse.typing import ClosedUnitInterval, PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys, InterVars
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_WORDS
@@ -47,6 +47,7 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio
self.model_key = None
self.lang = lang

if tokenization:
self.model_key = prepare_model(lang=lang,
@@ -62,7 +63,8 @@ def compute_stats(self, sample, context=False):
if context and words_key in sample[Fields.context]:
words = sample[Fields.context][words_key]
else:
tokenizer = MODEL_ZOO.get(self.model_key, None)
tokenizer = get_model(self.model_key, lang=self.lang,
model_type='sentencepiece')
words = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
@@ -1,6 +1,6 @@
from jsonargparse.typing import List

from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Mapper
from ..common import (SPECIAL_CHARACTERS, get_words_from_document,
@@ -32,6 +32,7 @@ def __init__(self,
super().__init__(*args, **kwargs)
self.tokenization = tokenization
self.substrings = substrings
self.lang = lang
if tokenization:
self.model_key = prepare_model(lang=lang,
model_type='sentencepiece')
@@ -43,7 +44,7 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings):

def process(self, sample):
if self.tokenization:
tokenizer = MODEL_ZOO.get(self.model_key, None)
tokenizer = get_model(self.model_key, lang=self.lang, model_type='sentencepiece')
sentences = get_words_from_document(
sample[self.text_key],
token_func=tokenizer.encode_as_pieces if tokenizer else None)
5 changes: 3 additions & 2 deletions data_juicer/ops/mapper/sentence_split_mapper.py
@@ -1,4 +1,4 @@
from data_juicer.utils.model_utils import MODEL_ZOO, prepare_model
from data_juicer.utils.model_utils import prepare_model, get_model

from ..base_op import OPERATORS, Mapper
from ..common import get_sentences_from_document
@@ -17,11 +17,12 @@ def __init__(self, lang: str = 'en', *args, **kwargs):
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.lang = lang
self.model_key = prepare_model(lang=lang, model_type='nltk')

def process(self, sample):

nltk_model = MODEL_ZOO.get(self.model_key, None)
nltk_model = get_model(self.model_key, lang=self.lang, model_type='nltk')
sample[self.text_key] = get_sentences_from_document(
sample[self.text_key],
model_func=nltk_model.tokenize if nltk_model else None)
6 changes: 4 additions & 2 deletions data_juicer/utils/logger_utils.py
@@ -92,14 +92,15 @@ def get_log_file_path():
return handler._sink._file.name


def setup_logger(save_dir, distributed_rank=0, filename='log.txt', mode='o'):
def setup_logger(save_dir, distributed_rank=0, filename='log.txt', mode='o', redirect=True):
"""
Setup logger for training and testing.
:param save_dir: location to save log file
:param distributed_rank: device rank when multi-gpu environment
:param filename: log file name to save
:param mode: log file write mode, `append` or `override`. default is `o`.
:param redirect: whether to redirect system output
:return: logger instance.
"""
global LOGGER_SETUP
@@ -128,7 +129,8 @@ def setup_logger(save_dir, distributed_rank=0, filename='log.txt', mode='o'):
logger.add(save_file)

# redirect stdout/stderr to loguru
redirect_sys_output('INFO')
if redirect:
redirect_sys_output('INFO')
LOGGER_SETUP = True

class HiddenPrints:
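The new redirect flag lets callers keep stdout and stderr untouched, which matters when Ray owns the worker output streams. A small sketch mirroring the config.py call above; the save directory and file name are illustrative:

from data_juicer.utils.logger_utils import setup_logger

executor_type = 'ray'  # illustrative; normally read from the parsed config
setup_logger(save_dir='./outputs/log',
             filename='run.txt',
             redirect=(executor_type == 'default'))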
11 changes: 11 additions & 0 deletions data_juicer/utils/model_utils.py
@@ -202,3 +202,14 @@ def prepare_model(lang='en', model_type='sentencepiece', model_key=None):
else:
MODEL_ZOO[model_key] = model_func(model_name, lang)
return model_key


def get_model(model_key, lang='en', model_type='sentencepiece'):
"""
Get a model or a tokenizer from MODEL_ZOO.
:param model_key: key of the model or tokenizer in MODEL_ZOO
:param lang: language used to prepare the model if it is not loaded yet
:param model_type: model type used to prepare the model if it is not loaded yet
"""
if model_key not in MODEL_ZOO:
prepare_model(lang=lang, model_type=model_type, model_key=model_key)
return MODEL_ZOO.get(model_key, None)
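The get_model helper is what lets the operators above drop their direct MODEL_ZOO.get(...) lookups: if the requested key is missing, for example inside a fresh Ray worker process whose MODEL_ZOO starts empty, the model is prepared on the fly. A rough sketch of the lazy-lookup pattern as the filters use it; the sample text is illustrative:

from data_juicer.utils.model_utils import prepare_model, get_model

# Register a sentencepiece tokenizer key once (e.g. in an op's __init__).
model_key = prepare_model(lang='en', model_type='sentencepiece')

# Later, possibly in another process, fetch it lazily; get_model prepares
# the model first if the key is not yet in MODEL_ZOO.
tokenizer = get_model(model_key, lang='en', model_type='sentencepiece')
pieces = tokenizer.encode_as_pieces('Data-Juicer on Ray') if tokenizer else []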