Skip to content

Commit

Permalink
Merge pull request #31684 Basic yaml-defined provider.
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Jul 3, 2024
2 parents 1e873f4 + d2df083 commit 3212688
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 20 deletions.
8 changes: 5 additions & 3 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,14 @@ def option_from_runner_api(


def schema_field(
name: str, field_type: Union[schema_pb2.FieldType,
type]) -> schema_pb2.Field:
name: str,
field_type: Union[schema_pb2.FieldType, type],
description: Optional[str] = None) -> schema_pb2.Field:
return schema_pb2.Field(
name=name,
type=field_type if isinstance(field_type, schema_pb2.FieldType) else
typing_to_runner_api(field_type))
typing_to_runner_api(field_type),
description=description)


class SchemaTranslation(object):
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/yaml/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def maybe_nullable(beam_type, nullable):
fields=[
schemas.schema_field(
name,
maybe_nullable(json_type_to_beam_type(t), name not in required))
maybe_nullable(json_type_to_beam_type(t), name not in required),
description=t.get('description') if isinstance(t, dict) else None)
for (name, t) in json_schema['properties'].items()
])

Expand Down
15 changes: 2 additions & 13 deletions sdks/python/apache_beam/yaml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import contextlib
import json

import jinja2
import yaml

import apache_beam as beam
Expand Down Expand Up @@ -109,13 +108,6 @@ def _pipeline_spec_from_args(known_args):
return pipeline_yaml


class _BeamFileIOLoader(jinja2.BaseLoader):
def get_source(self, environment, path):
with FileSystems.open(path) as fin:
source = fin.read().decode()
return source, path, lambda: True


@contextlib.contextmanager
def _fix_xlang_instant_coding():
# Scoped workaround for https://github.com/apache/beam/issues/28151.
Expand All @@ -132,11 +124,8 @@ def run(argv=None):
argv = _preparse_jinja_flags(argv)
known_args, pipeline_args = _parse_arguments(argv)
pipeline_template = _pipeline_spec_from_args(known_args)
pipeline_yaml = ( # keep formatting
jinja2.Environment(
undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
.from_string(pipeline_template)
.render(**known_args.jinja_variables or {}))
pipeline_yaml = yaml_transform.expand_jinja(
pipeline_template, known_args.jinja_variables or {})
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)

with _fix_xlang_instant_coding():
Expand Down
57 changes: 57 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import subprocess
import sys
import urllib.parse
import warnings
from typing import Any
from typing import Callable
from typing import Dict
Expand Down Expand Up @@ -61,6 +62,7 @@
from apache_beam.utils import python_callable
from apache_beam.utils import subprocess_server
from apache_beam.version import __version__ as beam_version
from apache_beam.yaml import json_utils


class Provider:
Expand Down Expand Up @@ -376,6 +378,61 @@ def _affinity(self, other: "Provider"):
return super()._affinity(other)


@ExternalProvider.register_provider_type('yaml')
class YamlProvider(Provider):
def __init__(self, transforms: Mapping[str, Mapping[str, Any]]):
if not isinstance(transforms, dict):
raise ValueError('Transform mapping must be a dict.')
self._transforms = transforms

def available(self):
return True

def cache_artifacts(self):
pass

def provided_transforms(self):
return self._transforms.keys()

def config_schema(self, type):
return json_utils.json_schema_to_beam_schema(self.json_config_schema(type))

def json_config_schema(self, type):
return dict(
type='object',
additionalProperties=False,
**self._transforms[type]['config_schema'])

def description(self, type):
return self._transforms[type].get('description')

def requires_inputs(self, type, args):
return self._transforms[type].get(
'requires_inputs', super().requires_inputs(type, args))

def create_transform(
self,
type: str,
args: Mapping[str, Any],
yaml_create_transform: Callable[
[Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
) -> beam.PTransform:
from apache_beam.yaml.yaml_transform import SafeLineLoader, YamlTransform
spec = self._transforms[type]
try:
import jsonschema
jsonschema.validate(args, self.json_config_schema(type))
except ImportError:
warnings.warn(
'Please install jsonschema '
f'for better provider validation of "{type}"')
body = spec['body']
if not isinstance(body, str):
body = yaml.safe_dump(SafeLineLoader.strip_metadata(body))
from apache_beam.yaml.yaml_transform import expand_jinja
return YamlTransform(expand_jinja(body, args))


# This is needed because type inference can't handle *args, **kwargs forwarding.
# TODO(BEAM-24755): Add support for type inference of through kwargs calls.
def fix_pycallable():
Expand Down
58 changes: 58 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@

import yaml

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_provider import YamlProviders
from apache_beam.yaml.yaml_transform import SafeLineLoader
from apache_beam.yaml.yaml_transform import YamlTransform


class WindowIntoTest(unittest.TestCase):
Expand Down Expand Up @@ -143,6 +147,60 @@ def test_nested_include(self):
flattened)


class YamlDefinedProider(unittest.TestCase):
def test_yaml_define_provider(self):
providers = '''
- type: yaml
transforms:
Range:
config_schema:
properties:
end: {type: integer}
requires_inputs: false
body: |
type: Create
config:
elements:
{% for ix in range(end) %}
- {{ix}}
{% endfor %}
Power:
config_schema:
properties:
n: {type: integer}
body:
type: chain
transforms:
- type: MapToFields
config:
language: python
append: true
fields:
power: "element**{{n}}"
'''

pipeline = '''
type: chain
transforms:
- type: Range
config:
end: 4
- type: Power
config:
n: 2
'''

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
result = p | YamlTransform(
pipeline,
providers=yaml_provider.parse_providers(
yaml.load(providers, Loader=SafeLineLoader)))
assert_that(
result | beam.Map(lambda x: (x.element, x.power)),
equal_to([(0, 0), (1, 1), (2, 4), (3, 9)]))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
25 changes: 22 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from typing import Mapping
from typing import Set

import jinja2
import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.yaml import yaml_provider
Expand Down Expand Up @@ -160,9 +162,10 @@ def create_uuid(cls):
def strip_metadata(cls, spec, tagged_str=True):
if isinstance(spec, Mapping):
return {
key: cls.strip_metadata(value, tagged_str)
for key,
value in spec.items() if key not in ('__line__', '__uuid__')
cls.strip_metadata(key, tagged_str):
cls.strip_metadata(value, tagged_str)
for (key, value) in spec.items()
if key not in ('__line__', '__uuid__')
}
elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
return [cls.strip_metadata(value, tagged_str) for value in spec]
Expand Down Expand Up @@ -969,6 +972,22 @@ def preprocess_languages(spec):
return spec


class _BeamFileIOLoader(jinja2.BaseLoader):
def get_source(self, environment, path):
with FileSystems.open(path) as fin:
source = fin.read().decode()
return source, path, lambda: True


def expand_jinja(
jinja_template: str, jinja_variables: Mapping[str, Any]) -> str:
return ( # keep formatting
jinja2.Environment(
undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
.from_string(jinja_template)
.render(**jinja_variables))


class YamlTransform(beam.PTransform):
def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-value
if isinstance(spec, str):
Expand Down

0 comments on commit 3212688

Please sign in to comment.