From 1aa6147fe3d96caec04e6bb8998c5a227f726ae3 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 23 Jun 2023 23:14:06 -0700
Subject: [PATCH] Basic error handling for yaml. (#27145)

---
 sdks/python/apache_beam/transforms/core.py    |  94 +++++++++++++-
 sdks/python/apache_beam/yaml/yaml_mapping.py  |  48 ++++++-
 .../python/apache_beam/yaml/yaml_transform.py |  37 +++++-
 .../apache_beam/yaml/yaml_transform_test.py   | 121 +++++++++++++++++-
 4 files changed, 288 insertions(+), 12 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 95c2617a5a08..026625d7805d 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2182,6 +2182,9 @@ def expand(self, pcoll):
         *self._args, **self._kwargs).with_outputs(
             self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
+    #TODO(BEAM-18957): Fix when type inference supports tagged outputs.
+    result[self._main_tag].element_type = self._fn.infer_output_type(
+        pcoll.element_type)
 
     if self._threshold < 1.0:
@@ -2244,6 +2247,81 @@ def process(self, *args, **kwargs):
                     traceback.format_exception(*sys.exc_info()))))
 
+
+# Idea adapted from https://github.com/tosun-si/asgarde.
+# TODO(robertwb): Consider how this could fit into the public API.
+# TODO(robertwb): Generalize to all PValue types.
+class _PValueWithErrors(object):
+  """This wraps a PCollection such that transforms can be chained in a linear
+  manner while still accumulating any errors."""
+  def __init__(self, pcoll, exception_handling_args, upstream_errors=()):
+    self._pcoll = pcoll
+    self._exception_handling_args = exception_handling_args
+    self._upstream_errors = upstream_errors
+
+  def main_output_tag(self):
+    return self._exception_handling_args.get('main_tag', 'good')
+
+  def error_output_tag(self):
+    return self._exception_handling_args.get('dead_letter_tag', 'bad')
+
+  def __or__(self, transform):
+    return self.apply(transform)
+
+  def apply(self, transform):
+    result = self._pcoll | transform.with_exception_handling(
+        **self._exception_handling_args)
+    if result[self.main_output_tag()].element_type == typehints.Any:
+      result[self.main_output_tag()].element_type = transform.infer_output_type(
+          self._pcoll.element_type)
+    # TODO(BEAM-18957): Add support for tagged type hints.
+    result[self.error_output_tag()].element_type = typehints.Any
+    return _PValueWithErrors(
+        result[self.main_output_tag()],
+        self._exception_handling_args,
+        self._upstream_errors + (result[self.error_output_tag()], ))
+
+  def accumulated_errors(self):
+    if len(self._upstream_errors) == 1:
+      return self._upstream_errors[0]
+    else:
+      return self._upstream_errors | Flatten()
+
+  def as_result(self, error_post_processing=None):
+    return {
+        self.main_output_tag(): self._pcoll,
+        self.error_output_tag(): self.accumulated_errors()
+        if error_post_processing is None else self.accumulated_errors()
+        | error_post_processing,
+    }
+
+
+class _MaybePValueWithErrors(object):
+  """This is like _PValueWithErrors, but only wraps values if
+  exception_handling_args is non-trivial. It is useful for handling
+  error-catching and non-error-catching code in a uniform manner.
+  """
+  def __init__(self, pvalue, exception_handling_args=None):
+    if isinstance(pvalue, _PValueWithErrors):
+      assert exception_handling_args is None
+      self._pvalue = pvalue
+    elif exception_handling_args is None:
+      self._pvalue = pvalue
+    else:
+      self._pvalue = _PValueWithErrors(pvalue, exception_handling_args)
+
+  def __or__(self, transform):
+    return self.apply(transform)
+
+  def apply(self, transform):
+    return _MaybePValueWithErrors(self._pvalue | transform)
+
+  def as_result(self, error_post_processing=None):
+    if isinstance(self._pvalue, _PValueWithErrors):
+      return self._pvalue.as_result(error_post_processing)
+    else:
+      return self._pvalue
+
+
 class _SubprocessDoFn(DoFn):
   """Process method run in a subprocess, turning hard crashes into exceptions.
   """
@@ -3232,14 +3310,21 @@ def __init__(
         _expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
     ] + [(name, _expr_to_callable(expr, name))
          for (name, expr) in kwargs.items()]
+    self._exception_handling_args = None
+
+  def with_exception_handling(self, **kwargs):
+    self._exception_handling_args = kwargs
+    return self
 
   def default_label(self):
     return 'ToRows(%s)' % ', '.join(name for name, _ in self._fields)
 
   def expand(self, pcoll):
-    return pcoll | Map(
-        lambda x: pvalue.Row(**{name: expr(x)
-                                for name, expr in self._fields}))
+    return (
+        _MaybePValueWithErrors(pcoll, self._exception_handling_args) | Map(
+            lambda x: pvalue.Row(
+                **{name: expr(x)
+                   for name, expr in self._fields}))).as_result()
 
   def infer_output_type(self, input_type):
     return row_type.RowTypeConstraint.from_fields([
@@ -3430,6 +3515,9 @@ def process(
       new_windows = self.windowing.windowfn.assign(context)
       yield WindowedValue(element, context.timestamp, new_windows)
 
+  def infer_output_type(self, input_type):
+    return input_type
+
   def __init__(
       self,
      windowfn,  # type: typing.Union[Windowing, WindowFn]
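The two private helpers above are the core of this change: _PValueWithErrors lets transforms be chained with '|' while failures from every stage accumulate into a single error output, and _MaybePValueWithErrors degrades to a plain passthrough when no exception handling is requested. A minimal sketch of the chaining (illustrative only, not part of the patch; the class is private to core.py, and parse/invert are hypothetical functions):

    import apache_beam as beam
    from apache_beam.transforms.core import _PValueWithErrors  # internal

    def parse(s):
      return int(s)  # raises on 'not-a-number'

    def invert(x):
      return 1.0 / x  # raises on 0

    with beam.Pipeline() as p:
      lines = p | beam.Create(['1', '2', '0', 'not-a-number'])
      wrapped = _PValueWithErrors(
          lines, {'main_tag': 'good', 'dead_letter_tag': 'bad'})
      # Each '|' applies the transform with exception handling; bad records
      # are collected rather than failing the pipeline.
      wrapped = wrapped | beam.Map(parse) | beam.Map(invert)
      outputs = wrapped.as_result()
      # outputs['good']: elements that survived both stages.
      # outputs['bad']: the failing elements, each paired with exception
      # details, flattened across all stages of the chain.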
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py
index 3b25eb78a39d..5ea3fcb02fb1 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -62,6 +62,7 @@ class _Explode(beam.PTransform):
   def __init__(self, fields, cross_product):
     self._fields = fields
     self._cross_product = cross_product
+    self._exception_handling_args = None
 
   def expand(self, pcoll):
     all_fields = [
@@ -86,11 +87,15 @@ def explode_zip(base, fields):
         copy[field] = values[ix]
         yield beam.Row(**copy)
 
-    return pcoll | beam.FlatMap(
-        lambda row: (
-            explode_cross_product if self._cross_product else explode_zip)(
-                {name: getattr(row, name) for name in all_fields},  # yapf break
-                to_explode))
+    return (
+        beam.core._MaybePValueWithErrors(
+            pcoll, self._exception_handling_args)
+        | beam.FlatMap(
+            lambda row: (
+                explode_cross_product if self._cross_product else explode_zip)(
+                    {name: getattr(row, name) for name in all_fields},  # yapf
+                    to_explode))
+    ).as_result()
 
   def infer_output_type(self, input_type):
     return row_type.RowTypeConstraint.from_fields([(
@@ -98,16 +103,37 @@ def infer_output_type(self, input_type):
         name,
         trivial_inference.element_type(typ) if name in self._fields else typ)
         for (name, typ) in named_fields_from_element_type(input_type)])
 
+  def with_exception_handling(self, **kwargs):
+    # It's possible there's an error in iteration...
+    self._exception_handling_args = kwargs
+    return self
+
 
 # TODO(yaml): Should Filter and Explode be distinct operations from Project?
 # We'll want these per-language.
 @beam.ptransform.ptransform_fn
 def _PythonProjectionTransform(
-    pcoll, *, fields, keep=None, explode=(), cross_product=True):
+    pcoll,
+    *,
+    fields,
+    keep=None,
+    explode=(),
+    cross_product=True,
+    error_handling=None):
   original_fields = [
       name for (name, _) in named_fields_from_element_type(pcoll.element_type)
   ]
 
+  if error_handling is None:
+    error_handling_args = None
+  else:
+    error_handling_args = {
+        'dead_letter_tag' if k == 'output' else k: v
+        for (k, v) in error_handling.items()
+    }
+
+  pcoll = beam.core._MaybePValueWithErrors(pcoll, error_handling_args)
+
   if keep:
     if isinstance(keep, str) and keep in original_fields:
       keep_fn = lambda row: getattr(row, keep)
@@ -131,7 +157,11 @@ def _PythonProjectionTransform(
   else:
     result = projected
 
-  return result
+  return result.as_result(
+      beam.MapTuple(
+          lambda element,
+          exc_info: beam.Row(
+              element=element, msg=str(exc_info[1]), stack=str(exc_info[2]))))
 
 
 @beam.ptransform.ptransform_fn
@@ -146,6 +176,7 @@ def MapToFields(
     append=False,
     drop=(),
     language=None,
+    error_handling=None,
    **language_keywords):
 
   if isinstance(explode, str):
@@ -192,6 +223,8 @@ def MapToFields(
     language = "python"
 
   if language in ("sql", "calcite"):
+    if error_handling:
+      raise ValueError('Error handling unsupported for sql.')
     selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
     query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
     if keep:
@@ -215,6 +248,7 @@ def MapToFields(
             'keep': keep,
             'explode': explode,
             'cross_product': cross_product,
+            'error_handling': error_handling,
         },
         **language_keywords
     }, [pcoll])
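With the plumbing above, a Python MapToFields step accepts an error_handling block whose 'output' key names its error PCollection; the key is re-mapped to dead_letter_tag internally, and failed records are rendered by the MapTuple above into rows with element, msg, and stack fields. A sketch of one such record, based on the division-by-zero case in the tests below (values illustrative):

    import apache_beam as beam

    # One record on the declared 'errors' output of a MapToFields step:
    beam.Row(
        element=beam.Row(num=0, str='bbb'),  # the input that failed
        msg="ZeroDivisionError('division by zero')",  # str(exc_info[1])
        stack='...')  # str(exc_info[2]), the formatted traceback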
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index ebc9eb6c066c..169957e03371 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -192,6 +192,11 @@ def get_pcollection(self, name):
       if len(outputs) == 1:
         return only_element(outputs.values())
       else:
+        error_output = self._transforms_by_uuid[self.get_transform_id(
+            name)].get('error_handling', {}).get('output')
+        if error_output and error_output in outputs and len(outputs) == 2:
+          return next(
+              output for tag, output in outputs.items() if tag != error_output)
         raise ValueError(
             f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
             f'{name} has outputs {list(outputs.keys())}')
@@ -655,6 +660,34 @@ def all_inputs(t):
   return dict(spec, transforms=new_transforms)
 
 
+def ensure_transforms_have_types(spec):
+  if 'type' not in spec:
+    raise ValueError(f'Missing type specification in {identify_object(spec)}')
+  return spec
+
+
+def ensure_errors_consumed(spec):
+  if spec['type'] == 'composite':
+    scope = LightweightScope(spec['transforms'])
+    to_handle = {}
+    consumed = set(
+        scope.get_transform_id_and_output_name(output)
+        for output in spec['output'].values())
+    for t in spec['transforms']:
+      if 'error_handling' in t:
+        if 'output' not in t['error_handling']:
+          raise ValueError(
+              f'Missing output in error_handling of {identify_object(t)}')
+        to_handle[t['__uuid__'], t['error_handling']['output']] = t
+      for _, input in t['input'].items():
+        if input not in spec['input']:
+          consumed.add(scope.get_transform_id_and_output_name(input))
+    for error_pcoll, t in to_handle.items():
+      if error_pcoll not in consumed:
+        raise ValueError(f'Unconsumed error output for {identify_object(t)}.')
+  return spec
+
+
 def preprocess(spec, verbose=False):
   if verbose:
     pprint.pprint(spec)
@@ -666,10 +699,12 @@ def apply(phase, spec):
           spec, transforms=[apply(phase, t) for t in spec['transforms']])
     return spec
 
-  for phase in [preprocess_source_sink,
+  for phase in [ensure_transforms_have_types,
+                preprocess_source_sink,
                 preprocess_chain,
                 normalize_inputs_outputs,
                 preprocess_flattened_inputs,
+                ensure_errors_consumed,
                 preprocess_windowing]:
     spec = apply(phase, spec)
     if verbose:
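The yaml_transform.py changes above work as a pair: ensure_errors_consumed rejects a composite in which a declared error output is never read, and get_pcollection lets a bare transform name resolve to the non-error output when the only other output is the declared error tag. A sketch of both, given a pipeline p and the SizeLimiter test provider defined below (names illustrative):

    p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            elements: ['a', 'biiiiig']
          - type: SizeLimiter
            limit: 5
            input: Create
            error_handling:
              output: errors
          - name: HandleErrors
            type: PyMap
            input: SizeLimiter.errors  # without a consumer of this output,
            fn: "str"                  # construction fails with
                                       # 'Unconsumed error output ...'
        output:
          good: SizeLimiter  # bare name resolves to the non-error output
        ''',
        providers=TEST_PROVIDERS)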
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index b036ea0a0371..4ccc50da7573 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -226,11 +226,130 @@ def expand(self, pcoll):
     return pcoll | beam.CombineGlobally(sum).without_defaults()
 
 
+class SizeLimiter(beam.PTransform):
+  def __init__(self, limit, error_handling):
+    self._limit = limit
+    self._error_handling = error_handling
+
+  def expand(self, pcoll):
+    def raise_on_big(element):
+      if len(element) > self._limit:
+        raise ValueError(element)
+      else:
+        return element
+
+    good, bad = pcoll | beam.Map(raise_on_big).with_exception_handling()
+    return {'small_elements': good, self._error_handling['output']: bad}
+
+
 TEST_PROVIDERS = {
-    'CreateTimestamped': CreateTimestamped, 'SumGlobally': SumGlobally
+    'CreateTimestamped': CreateTimestamped,
+    'SumGlobally': SumGlobally,
+    'SizeLimiter': SizeLimiter,
 }
 
 
+class ErrorHandlingTest(unittest.TestCase):
+  def test_error_handling_outputs(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              elements: ['a', 'b', 'biiiiig']
+            - type: SizeLimiter
+              limit: 5
+              input: Create
+              error_handling:
+                output: errors
+            - name: TrimErrors
+              type: PyMap
+              input: SizeLimiter.errors
+              fn: "lambda x: x[1][1]"
+          output:
+            good: SizeLimiter
+            bad: TrimErrors
+          ''',
+          providers=TEST_PROVIDERS)
+      assert_that(result['good'], equal_to(['a', 'b']), label="CheckGood")
+      assert_that(result['bad'], equal_to(["ValueError('biiiiig')"]))
+
"IndexError('string index out of range')", # from the filter + "ZeroDivisionError('division by zero')", # from the mapping + ]), + label='CheckErrors') + + class YamlWindowingTest(unittest.TestCase): def test_explicit_window_into(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(