diff --git a/sdks/python/apache_beam/yaml/pipeline.schema.yaml b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
index 40f576c1618b7..f68a7306d9410 100644
--- a/sdks/python/apache_beam/yaml/pipeline.schema.yaml
+++ b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
@@ -154,6 +154,27 @@ $defs:
       - transforms
       - config
 
+  providerInclude:
+    # TODO(robertwb): Consider enumerating the provider types along with
+    # the arguments they accept/expect (possibly in a separate schema file).
+    type: object
+    properties:
+      include: { type: string }
+      __line__: {}
+      __uuid__: {}
+    additionalProperties: false
+    required:
+      - include
+
+  providerOrProviderInclude:
+    if:
+      properties:
+        include: {}
+    then:
+      $ref: '#/$defs/providerInclude'
+    else:
+      $ref: '#/$defs/provider'
+
 type: object
 properties:
   pipeline:
@@ -185,7 +206,7 @@ properties:
   providers:
     type: array
     items:
-      $ref: '#/$defs/provider'
+      $ref: '#/$defs/providerOrProviderInclude'
   options:
     type: object
 required:
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 794cad0ec7f47..46ec0700ee27b 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -34,6 +34,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Iterator
 from typing import Mapping
 from typing import Optional
 
@@ -45,6 +46,7 @@
 import apache_beam.dataframe.io
 import apache_beam.io
 import apache_beam.transforms.util
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.portability.api import schema_pb2
 from apache_beam.runners import pipeline_context
 from apache_beam.testing.util import assert_that
@@ -222,7 +224,10 @@ def provider_from_spec(cls, spec):
       config['version'] = beam_version
     if type in cls._provider_types:
       try:
-        return cls._provider_types[type](urns, **config)
+        result = cls._provider_types[type](urns, **config)
+        if not hasattr(result, 'to_json'):
+          result.to_json = lambda: spec
+        return result
       except Exception as exn:
         raise ValueError(
             f'Unable to instantiate provider of type {type} '
@@ -1153,18 +1158,44 @@ def cache_artifacts(self):
     self._underlying_provider.cache_artifacts()
 
 
-def parse_providers(provider_specs):
-  providers = collections.defaultdict(list)
+def flatten_included_provider_specs(
+    provider_specs: Iterable[Mapping]) -> Iterator[Mapping]:
+  from apache_beam.yaml.yaml_transform import SafeLineLoader
   for provider_spec in provider_specs:
-    provider = ExternalProvider.provider_from_spec(provider_spec)
-    for transform_type in provider.provided_transforms():
-      providers[transform_type].append(provider)
-      # TODO: Do this better.
-      provider.to_json = lambda result=provider_spec: result
-  return providers
+    if 'include' in provider_spec:
+      if len(SafeLineLoader.strip_metadata(provider_spec)) != 1:
+        raise ValueError(
+            f"When using include, it must be the only parameter: "
+            f"{provider_spec} "
+            f"at line {SafeLineLoader.get_line(provider_spec)}")
+      include_uri = provider_spec['include']
+      try:
+        with urllib.request.urlopen(include_uri) as response:
+          content = response.read()
+      except (ValueError, urllib.error.URLError) as exn:
+        if 'unknown url type' in str(exn):
+          with FileSystems.open(include_uri) as fin:
+            content = fin.read()
+        else:
+          raise
+      included_providers = yaml.load(content, Loader=SafeLineLoader)
+      if not isinstance(included_providers, list):
+        raise ValueError(
+            f"Included file {include_uri} must be a list of Providers "
+            f"at line {SafeLineLoader.get_line(provider_spec)}")
+      yield from flatten_included_provider_specs(included_providers)
+    else:
+      yield provider_spec
+
+
+def parse_providers(provider_specs: Iterable[Mapping]) -> Iterable[Provider]:
+  return [
+      ExternalProvider.provider_from_spec(provider_spec)
+      for provider_spec in flatten_included_provider_specs(provider_specs)
+  ]
 
 
-def merge_providers(*provider_sets):
+def merge_providers(*provider_sets) -> Mapping[str, Iterable[Provider]]:
   result = collections.defaultdict(list)
   for provider_set in provider_sets:
     if isinstance(provider_set, Provider):
diff --git a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
index ec71422fd161b..5a30c7d140bfe 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
@@ -16,9 +16,15 @@
 #
 
 import logging
+import os
+import tempfile
 import unittest
 
+import yaml
+
+from apache_beam.yaml import yaml_provider
 from apache_beam.yaml.yaml_provider import YamlProviders
+from apache_beam.yaml.yaml_transform import SafeLineLoader
 
 
 class WindowIntoTest(unittest.TestCase):
@@ -63,6 +69,80 @@ def test_parse_duration_with_missing_value(self):
       self.parse_duration('s', 'size')
 
 
+class ProviderParsingTest(unittest.TestCase):
+
+  INLINE_PROVIDER = {'type': 'TEST', 'name': 'INLINED'}
+  INCLUDED_PROVIDER = {'type': 'TEST', 'name': 'INCLUDED'}
+  EXTRA_PROVIDER = {'type': 'TEST', 'name': 'EXTRA'}
+
+  @classmethod
+  def setUpClass(cls):
+    cls.tempdir = tempfile.TemporaryDirectory()
+    cls.to_include = os.path.join(cls.tempdir.name, 'providers.yaml')
+    with open(cls.to_include, 'w') as fout:
+      yaml.dump([cls.INCLUDED_PROVIDER], fout)
+    cls.to_include_nested = os.path.join(
+        cls.tempdir.name, 'nested_providers.yaml')
+    with open(cls.to_include_nested, 'w') as fout:
+      yaml.dump([{'include': cls.to_include}, cls.EXTRA_PROVIDER], fout)
+
+  @classmethod
+  def tearDownClass(cls):
+    cls.tempdir.cleanup()
+
+  def test_include_file(self):
+    flattened = [
+        SafeLineLoader.strip_metadata(spec)
+        for spec in yaml_provider.flatten_included_provider_specs([
+            self.INLINE_PROVIDER,
+            {
+                'include': self.to_include
+            },
+        ])
+    ]
+
+    self.assertEqual([
+        self.INLINE_PROVIDER,
+        self.INCLUDED_PROVIDER,
+    ],
+                     flattened)
+
+  def test_include_url(self):
+    flattened = [
+        SafeLineLoader.strip_metadata(spec)
+        for spec in yaml_provider.flatten_included_provider_specs([
+            self.INLINE_PROVIDER,
+            {
+                'include': 'file:///' + self.to_include
+            },
+        ])
+    ]
+
+    self.assertEqual([
+        self.INLINE_PROVIDER,
+        self.INCLUDED_PROVIDER,
+    ],
+                     flattened)
+
+  def test_nested_include(self):
+    flattened = [
+        SafeLineLoader.strip_metadata(spec)
+        for spec in yaml_provider.flatten_included_provider_specs([
+            self.INLINE_PROVIDER,
+            {
+                'include': self.to_include_nested
+            },
+        ])
+    ]
+
+    self.assertEqual([
+        self.INLINE_PROVIDER,
+        self.INCLUDED_PROVIDER,
+        self.EXTRA_PROVIDER,
+    ],
+                     flattened)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index fd265c42cf73a..297143e94cc2c 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -1037,10 +1037,6 @@ def expand_pipeline(
   # Calling expand directly to avoid outer layer of nesting.
   return YamlTransform(
       pipeline_as_composite(pipeline_spec['pipeline']),
-      {
-          **yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
-          **{
-              key: yaml_provider.as_provider_list(key, value)
-              for (key, value) in (providers or {}).items()
-          }
-      }).expand(beam.pvalue.PBegin(pipeline))
+      yaml_provider.merge_providers(
+          yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
+          providers or {})).expand(beam.pvalue.PBegin(pipeline))
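
For context, the include mechanism introduced above lets a pipeline's providers list mix inline provider specs with entries whose only key is include, pointing at a local path or URL whose contents are themselves a YAML list of providers; nested includes are flattened recursively, and local paths fall back to FileSystems when urllib cannot open them. The following is a minimal sketch only, reusing the toy TEST specs from the unit tests above (file names and field values are illustrative, not part of the change):

# providers.yaml -- the included file: a plain YAML list of provider specs,
# mirroring the fixtures written by ProviderParsingTest.setUpClass.
- type: TEST
  name: INCLUDED

# Main pipeline file -- the providers section may now mix inline specs
# with include entries (local path or URL); includes may themselves nest.
pipeline:
  transforms:
    - type: SomeTransform   # illustrative
providers:
  - type: TEST              # inline spec, as before
    name: INLINED
  - include: ./providers.yaml            # local path, read via FileSystems
  - include: file:///tmp/providers.yaml  # or any URL urllib can open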