Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample level error catching #325

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,76 @@

import pandas as pd
import pyarrow as pa
from loguru import logger

from data_juicer.utils.mm_utils import size_to_bytes
from data_juicer.utils.registry import Registry

OPERATORS = Registry('Operators')


def convert_list_dict_to_dict_list(samples):
# reconstruct samples from "list of dicts" to "dict of lists"
keys = samples[0].keys()
res_samples = {}
for key in keys:
res_samples[key] = [s[key] for s in samples]
return res_samples


def convert_dict_list_to_list_dict(samples):
# reconstruct samples from "dict of lists" to "list of dicts"
reconstructed_samples = []
keys = list(samples.keys())
# take any key, since they should be of same length
for i in range(len(samples[keys[0]])):
reconstructed_samples.append({key: samples[key][i] for key in samples})
return reconstructed_samples


def catch_exception_mapper_process(method):
"""
For mapper sample level fault torelerance.
"""

def wrapper(self, *args, **kwargs):
try:
return method(self, *args, **kwargs)
except Exception as e:
samples = args[0]
logger.error(
f'An error occurred in mapper operation when processing'
f'sample {samples}, {type(e)}: {e}')
return {}

return wrapper


def catch_exception_mapper_process_single(method):
"""
For mapper process_single,
turn it into batch_size = 1, and enable fault torelerance.
"""

def wrapper(self, *args, **kwargs):
try:
args = list(args)
samples = args[0]
sample = convert_dict_list_to_list_dict(samples)[0]
args[0] = sample
args = tuple(args)
res_sample = method(self, *args, **kwargs)
return convert_list_dict_to_dict_list([res_sample])
except Exception as e:
samples = args[0]
logger.error(
f'An error occurred in mapper operation when processing'
f'sample {samples}, {type(e)}: {e}')
return {}

return wrapper


class OP:

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -87,6 +150,7 @@ def ray_batch_mapper_wrapper(samples, fn):
return pa.Table.from_pandas(res)


# @mapper_fault_tolerance
class Mapper(OP):

def __init__(self, *args, **kwargs):
Expand All @@ -105,7 +169,9 @@ def __init__(self, *args, **kwargs):
super(Mapper, self).__init__(*args, **kwargs)

# In default, it's a normal OP instead of batched OP
self._batched_op = kwargs.get('batched_op', False)
# self._batched_op = kwargs.get('batched_op', False)
# Aftet the refactor, we want all ops to be batched OP by default
self._batched_op = kwargs.get('batched_op', True)

def process(self, sample):
"""
Expand Down Expand Up @@ -134,6 +200,7 @@ def __call__(self, sample):
return self.process(sample)


# @filter_fault_tolerance
class Filter(OP):

def __init__(self, *args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from data_juicer.utils.file_utils import transfer_filename
from data_juicer.utils.logger_utils import HiddenPrints

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single

OP_NAME = 'audio_ffmpeg_wrapped_mapper'

Expand Down Expand Up @@ -47,8 +47,8 @@ def __init__(
self.capture_stderr = capture_stderr
self.overwrite_output = overwrite_output

def process(self, sample):
# there is no audio in this sample
@catch_exception_mapper_process_single
def process(self, sample): # there is no audio in this sample
if self.audio_key not in sample or not sample[self.audio_key]:
return sample

Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/clean_copyright_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import regex as re

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single


@OPERATORS.register_module('clean_copyright_mapper')
Expand All @@ -23,8 +23,8 @@ def __init__(self, *args, **kwargs):
self.pat = re.compile('/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/')
self.cpat = re.compile('copyright', re.IGNORECASE)

@catch_exception_mapper_process_single
def process(self, sample):

r = self.pat.search(sample[self.text_key])
if r:
# found one, now see if it contains "copyright", if so strip it
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/clean_email_mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import regex as re

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single


@OPERATORS.register_module('clean_email_mapper')
Expand Down Expand Up @@ -28,8 +28,8 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):

self.repl = repl

@catch_exception_mapper_process_single
def process(self, sample):

if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
return sample

Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/clean_html_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from data_juicer.utils.availability_utils import AvailabilityChecking

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single

OP_NAME = 'clean_html_mapper'

Expand All @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)

@catch_exception_mapper_process_single
def process(self, sample):

def _clean_html(raw_html):
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/clean_ip_mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import regex as re

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single


@OPERATORS.register_module('clean_ip_mapper')
Expand Down Expand Up @@ -32,8 +32,8 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
self.pattern = pattern[2:-1]
self.repl = repl

@catch_exception_mapper_process_single
def process(self, sample):

if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
return sample

Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/clean_links_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# --------------------------------------------------------
import regex as re

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single


@OPERATORS.register_module('clean_links_mapper')
Expand Down Expand Up @@ -38,8 +38,8 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
self.pattern = pattern[2:-1]
self.repl = repl

@catch_exception_mapper_process_single
def process(self, sample):

if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
return sample

Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/expand_macro_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import regex as re

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single


@OPERATORS.register_module('expand_macro_mapper')
Expand Down Expand Up @@ -55,6 +55,7 @@ def _build_non_arg_macros_dict(self, file_content):
macros[macro_name] = macro_val
return macros

@catch_exception_mapper_process_single
def process(self, sample):
non_arg_macros = self._build_non_arg_macros_dict(sample[self.text_key])

Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/fix_unicode_mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from data_juicer.utils.availability_utils import AvailabilityChecking

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single

OP_NAME = 'fix_unicode_mapper'

Expand Down Expand Up @@ -33,6 +33,7 @@ def __init__(self, normalization: str = None, *args, **kwargs):
'supported. Can only be one of '
'["NFC", "NFKC", "NFD", "NFKD"]')

@catch_exception_mapper_process_single
def process(self, sample):
sample[self.text_key] = ftfy.fix_text(sample[self.text_key],
normalization=self.normalization)
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/image_blur_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from data_juicer.utils.file_utils import transfer_filename
from data_juicer.utils.mm_utils import load_data_with_context, load_image

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_blur_mapper'
Expand Down Expand Up @@ -53,6 +53,7 @@ def __init__(self,
else:
self.blur = ImageFilter.GaussianBlur(radius)

@catch_exception_mapper_process_single
def process(self, sample, context=False):
# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
remove_non_special_tokens,
remove_special_tokens)

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process
from ..op_fusion import LOADED_IMAGES

SYSTEM_PROMPTS = {
Expand Down Expand Up @@ -244,6 +244,7 @@ def _process_single_sample(self, sample):

return [generated_sample]

@catch_exception_mapper_process
def process(self, samples):
# reconstruct samples from "dict of lists" to "list of dicts"
reconstructed_samples = []
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/image_captioning_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
remove_special_tokens)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_captioning_mapper'
Expand Down Expand Up @@ -272,6 +272,7 @@ def _reduce_captions_per_image(self, chunk,
generated_text_candidates_single_chunk[max_index])
return new_generated_text_per_chunk

@catch_exception_mapper_process
def process(self, samples, rank=None):
"""
Note:
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/image_diffusion_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
load_image, remove_special_tokens)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_diffusion_mapper'
Expand Down Expand Up @@ -209,6 +209,7 @@ def _process_single_sample(self, ori_sample, rank=None, context=False):

return generated_samples

@catch_exception_mapper_process
def process(self, samples, rank=None, context=False):
"""
Note:
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/image_face_blur_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from data_juicer.utils.mm_utils import (load_data_with_context, load_image,
pil_to_opencv)

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_face_blur_mapper'
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(self,
# Initialize face detector
self.detector = dlib.get_frontal_face_detector()

@catch_exception_mapper_process_single
def process(self, sample, context=False):
# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/nlpaug_en_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from data_juicer.utils.availability_utils import AvailabilityChecking

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process

OP_NAME = 'nlpaug_en_mapper'

Expand Down Expand Up @@ -122,6 +122,7 @@ def __init__(self,
else:
self.aug = aug_pipeline

@catch_exception_mapper_process
def process(self, samples):
# no augmentation methods are opened
if len(self.aug) == 0:
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/nlpcda_zh_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.logger_utils import HiddenPrints

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process

OP_NAME = 'nlpcda_zh_mapper'

Expand Down Expand Up @@ -128,6 +128,7 @@ def __init__(self,
self.aug_pipeline.append(
nlpcda.EquivalentChar(create_num=create_num))

@catch_exception_mapper_process
def process(self, samples):
# no augmentation methods are opened
if len(self.aug_pipeline) == 0:
Expand Down
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/punctuation_normalization_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# https://github.com/bigscience-workshop/data-preparation
# --------------------------------------------------------

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, Mapper, catch_exception_mapper_process_single


@OPERATORS.register_module('punctuation_normalization_mapper')
Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs):
'►': '-',
}

@catch_exception_mapper_process_single
def process(self, sample):
sample[self.text_key] = ''.join([
self.punctuation_unicode.get(c, c) for c in sample[self.text_key]
Expand Down
Loading
Loading