Basic error handling for yaml. (apache#27145)
robertwb authored and cushon committed May 24, 2024
1 parent 7f0154c commit 1aa6147
Showing 4 changed files with 288 additions and 12 deletions.
94 changes: 91 additions & 3 deletions sdks/python/apache_beam/transforms/core.py
@@ -2182,6 +2182,9 @@ def expand(self, pcoll):
*self._args,
**self._kwargs).with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
# TODO(BEAM-18957): Fix when type inference supports tagged outputs.
result[self._main_tag].element_type = self._fn.infer_output_type(
pcoll.element_type)

if self._threshold < 1.0:

@@ -2244,6 +2247,81 @@ def process(self, *args, **kwargs):
traceback.format_exception(*sys.exc_info()))))


# Idea adapted from https://github.com/tosun-si/asgarde.
# TODO(robertwb): Consider how this could fit into the public API.
# TODO(robertwb): Generalize to all PValue types.
class _PValueWithErrors(object):
"""This wraps a PCollection such that transforms can be chained in a linear
manner while still accumulating any errors."""
def __init__(self, pcoll, exception_handling_args, upstream_errors=()):
self._pcoll = pcoll
self._exception_handling_args = exception_handling_args
self._upstream_errors = upstream_errors

def main_output_tag(self):
return self._exception_handling_args.get('main_tag', 'good')

def error_output_tag(self):
return self._exception_handling_args.get('dead_letter_tag', 'bad')

def __or__(self, transform):
return self.apply(transform)

def apply(self, transform):
result = self._pcoll | transform.with_exception_handling(
**self._exception_handling_args)
if result[self.main_output_tag()].element_type == typehints.Any:
result[self.main_output_tag()].element_type = transform.infer_output_type(
self._pcoll.element_type)
# TODO(BEAM-18957): Add support for tagged type hints.
result[self.error_output_tag()].element_type = typehints.Any
return _PValueWithErrors(
result[self.main_output_tag()],
self._exception_handling_args,
self._upstream_errors + (result[self.error_output_tag()], ))

def accumulated_errors(self):
if len(self._upstream_errors) == 1:
return self._upstream_errors[0]
else:
return self._upstream_errors | Flatten()

def as_result(self, error_post_processing=None):
return {
self.main_output_tag(): self._pcoll,
self.error_output_tag(): self.accumulated_errors()
if error_post_processing is None else self.accumulated_errors()
| error_post_processing,
}
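
As a sketch of the intended chaining (a hypothetical pipeline, not part of this commit; the 'good'/'bad' tags are the defaults this class assumes, and the elements are invented):

# Hypothetical usage sketch.
import apache_beam as beam

with beam.Pipeline() as p:
    lines = p | beam.Create(['1', '2', 'oops'])
    wrapped = _PValueWithErrors(lines, exception_handling_args={})
    # Each application runs with_exception_handling, so elements that
    # raise (here, int('oops')) accumulate on the error side while the
    # main output keeps flowing linearly.
    chained = wrapped | beam.Map(int) | beam.Map(lambda x: x * 10)
    outputs = chained.as_result()
    good, bad = outputs['good'], outputs['bad']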


class _MaybePValueWithErrors(object):
"""This is like _PValueWithErrors, but only wraps values if
exception_handling_args is non-trivial. It is useful for handling
error-catching and non-error-catching code in a uniform manner.
"""
def __init__(self, pvalue, exception_handling_args=None):
if isinstance(pvalue, _PValueWithErrors):
assert exception_handling_args is None
self._pvalue = pvalue
elif exception_handling_args is None:
self._pvalue = pvalue
else:
self._pvalue = _PValueWithErrors(pvalue, exception_handling_args)

def __or__(self, transform):
return self.apply(transform)

def apply(self, transform):
return _MaybePValueWithErrors(self._pvalue | transform)

def as_result(self, error_post_processing=None):
if isinstance(self._pvalue, _PValueWithErrors):
return self._pvalue.as_result(error_post_processing)
else:
return self._pvalue
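
A minimal sketch of the uniformity this buys (hypothetical helper, assuming apache_beam is imported as beam and that exception_handling_args is None when no error catching was requested):

# Hypothetical sketch: one expansion body serves both modes.
def expand_stripped(pcoll, exception_handling_args=None):
    return (
        _MaybePValueWithErrors(pcoll, exception_handling_args)
        | beam.Map(lambda x: x.strip())).as_result()

# With exception_handling_args=None this returns the plain mapped
# PCollection; otherwise a dict keyed by the 'good' and 'bad' tags.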


class _SubprocessDoFn(DoFn):
"""Process method run in a subprocess, turning hard crashes into exceptions.
"""
@@ -3232,14 +3310,21 @@ def __init__(
_expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
] + [(name, _expr_to_callable(expr, name))
for (name, expr) in kwargs.items()]
self._exception_handling_args = None

def with_exception_handling(self, **kwargs):
self._exception_handling_args = kwargs
return self

def default_label(self):
return 'ToRows(%s)' % ', '.join(name for name, _ in self._fields)

def expand(self, pcoll):
return pcoll | Map(
lambda x: pvalue.Row(**{name: expr(x)
for name, expr in self._fields}))
return (
_MaybePValueWithErrors(pcoll, self._exception_handling_args) | Map(
lambda x: pvalue.Row(
**{name: expr(x)
for name, expr in self._fields}))).as_result()

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint.from_fields([
@@ -3430,6 +3515,9 @@ def process(
new_windows = self.windowing.windowfn.assign(context)
yield WindowedValue(element, context.timestamp, new_windows)

def infer_output_type(self, input_type):
# Windowing reassigns windows but leaves element values unchanged.
return input_type

def __init__(
self,
windowfn, # type: typing.Union[Windowing, WindowFn]
48 changes: 41 additions & 7 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -62,6 +62,7 @@ class _Explode(beam.PTransform):
def __init__(self, fields, cross_product):
self._fields = fields
self._cross_product = cross_product
self._exception_handling_args = None

def expand(self, pcoll):
all_fields = [
@@ -86,28 +87,53 @@ def explode_zip(base, fields):
copy[field] = values[ix]
yield beam.Row(**copy)

return pcoll | beam.FlatMap(
lambda row: (
explode_cross_product if self._cross_product else explode_zip)(
{name: getattr(row, name) for name in all_fields}, # yapf break
to_explode))
return (
beam.core._MaybePValueWithErrors(
pcoll, self._exception_handling_args)
| beam.FlatMap(
lambda row: (
explode_cross_product if self._cross_product else explode_zip)(
{name: getattr(row, name) for name in all_fields}, # yapf break
to_explode))
).as_result()

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint.from_fields([(
name,
trivial_inference.element_type(typ) if name in self._fields else
typ) for (name, typ) in named_fields_from_element_type(input_type)])

def with_exception_handling(self, **kwargs):
# It's possible there's an error in iteration...
self._exception_handling_args = kwargs
return self
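
For intuition, a self-contained sketch of the two explode modes on plain dicts (illustrative values; equal-length lists assumed for the zip case, and the helper names here are invented rather than the transform's own):

import itertools

def cross_product_explode(base, fields):
    # Every combination of the listed fields' values.
    for values in itertools.product(*(base[f] for f in fields)):
        yield {**base, **dict(zip(fields, values))}

def zip_explode(base, fields):
    # Pairwise: the i-th value of each listed field, together.
    for values in zip(*(base[f] for f in fields)):
        yield {**base, **dict(zip(fields, values))}

row = {'a': 1, 'b': [2, 3], 'c': [4, 5]}
assert list(cross_product_explode(row, ['b', 'c'])) == [
    {'a': 1, 'b': 2, 'c': 4}, {'a': 1, 'b': 2, 'c': 5},
    {'a': 1, 'b': 3, 'c': 4}, {'a': 1, 'b': 3, 'c': 5}]
assert list(zip_explode(row, ['b', 'c'])) == [
    {'a': 1, 'b': 2, 'c': 4}, {'a': 1, 'b': 3, 'c': 5}]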


# TODO(yaml): Should Filter and Explode be distinct operations from Project?
# We'll want these per-language.
@beam.ptransform.ptransform_fn
def _PythonProjectionTransform(
pcoll, *, fields, keep=None, explode=(), cross_product=True):
pcoll,
*,
fields,
keep=None,
explode=(),
cross_product=True,
error_handling=None):
original_fields = [
name for (name, _) in named_fields_from_element_type(pcoll.element_type)
]

if error_handling is None:
error_handling_args = None
else:
# Map the YAML-facing key 'output' onto the 'dead_letter_tag'
# parameter expected by with_exception_handling.
error_handling_args = {
'dead_letter_tag' if k == 'output' else k: v
for (k, v) in error_handling.items()
}

pcoll = beam.core._MaybePValueWithErrors(pcoll, error_handling_args)

if keep:
if isinstance(keep, str) and keep in original_fields:
keep_fn = lambda row: getattr(row, keep)
@@ -131,7 +157,11 @@ def _PythonProjectionTransform(
else:
result = projected

return result
return result.as_result(
beam.MapTuple(
lambda element,
exc_info: beam.Row(
element=element, msg=str(exc_info[1]), stack=str(exc_info[2]))))
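
Illustratively, each dead-letter record coming out of this MapTuple is a schema'd row built from the captured exception triple (the element and message below are invented):

# Illustrative shape only: msg is str(exception) and stack is the
# str() of the traceback object captured via sys.exc_info().
beam.Row(
    element=beam.Row(id='17', value='not-a-number'),
    msg="invalid literal for int() with base 10: 'not-a-number'",
    stack='<traceback object at 0x7f...>')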


@beam.ptransform.ptransform_fn
@@ -146,6 +176,7 @@ def MapToFields(
append=False,
drop=(),
language=None,
error_handling=None,
**language_keywords):

if isinstance(explode, str):
@@ -192,6 +223,8 @@ def MapToFields(
language = "python"

if language in ("sql", "calcite"):
if error_handling:
raise ValueError('Error handling unsupported for sql.')
selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
if keep:
@@ -215,6 +248,7 @@ def _PythonProjectionTransform(
'keep': keep,
'explode': explode,
'cross_product': cross_product,
'error_handling': error_handling,
},
**language_keywords
}, [pcoll])
37 changes: 36 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform.py
@@ -192,6 +192,11 @@ def get_pcollection(self, name):
if len(outputs) == 1:
return only_element(outputs.values())
else:
# When a transform has exactly two outputs and one is its declared
# error output, the other is unambiguously the main output.
error_output = self._transforms_by_uuid[self.get_transform_id(
name)].get('error_handling', {}).get('output')
if error_output and error_output in outputs and len(outputs) == 2:
return next(
output for tag, output in outputs.items() if tag != error_output)
raise ValueError(
f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
f'{name} has outputs {list(outputs.keys())}')
@@ -655,6 +660,34 @@ def all_inputs(t):
return dict(spec, transforms=new_transforms)


def ensure_transforms_have_types(spec):
if 'type' not in spec:
raise ValueError(f'Missing type specification in {identify_object(spec)}')
return spec


def ensure_errors_consumed(spec):
if spec['type'] == 'composite':
scope = LightweightScope(spec['transforms'])
to_handle = {}
consumed = set(
scope.get_transform_id_and_output_name(output)
for output in spec['output'].values())
for t in spec['transforms']:
if 'error_handling' in t:
if 'output' not in t['error_handling']:
raise ValueError(
f'Missing output in error_handling of {identify_object(t)}')
to_handle[t['__uuid__'], t['error_handling']['output']] = t
for _, input in t['input'].items():
if input not in spec['input']:
consumed.add(scope.get_transform_id_and_output_name(input))
for error_pcoll, t in to_handle.items():
if error_pcoll not in consumed:
raise ValueError(f'Unconsumed error output for {identify_object(t)}.')
return spec
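
A sketch of a composite spec this check accepts (hypothetical transform types and '__uuid__' values; dropping the second transform would trigger the 'Unconsumed error output' error):

spec = {
    'type': 'composite',
    'input': {},
    'output': {'out': 'validate'},
    'transforms': [
        {'__uuid__': 'validate', 'type': 'MapToFields', 'input': {},
         'error_handling': {'output': 'errors'}},
        {'__uuid__': 'log_bad', 'type': 'LogForTesting',
         'input': {'bad': 'validate.errors'}},
    ],
}
ensure_errors_consumed(spec)  # passes: validate.errors is consumed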


def preprocess(spec, verbose=False):
if verbose:
pprint.pprint(spec)
@@ -666,10 +699,12 @@ def apply(phase, spec):
spec, transforms=[apply(phase, t) for t in spec['transforms']])
return spec

for phase in [preprocess_source_sink,
for phase in [ensure_transforms_have_types,
preprocess_source_sink,
preprocess_chain,
normalize_inputs_outputs,
preprocess_flattened_inputs,
ensure_errors_consumed,
preprocess_windowing]:
spec = apply(phase, spec)
if verbose: