[YAML] Allow explicitly including external provider lists. #31604

Merged: 5 commits, merged on Jun 21, 2024
Changes from 3 commits
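
For context, the behavior this PR adds: an entry of the form {'include': <path or URL>} in a providers list is replaced by the provider specs found at that location, and includes may nest. A rough sketch of the new behavior, mirroring the unit tests added below (the temporary file layout and the 'TEST' provider type are illustrative only, not something the framework requires):

import os
import tempfile

import yaml

from apache_beam.yaml import yaml_provider

# Write an external provider list that a pipeline author wants to reuse.
tmpdir = tempfile.mkdtemp()
providers_path = os.path.join(tmpdir, 'providers.yaml')
with open(providers_path, 'w') as fout:
  yaml.dump([{'type': 'TEST', 'name': 'INCLUDED'}], fout)

# A providers list may now contain an 'include' entry pointing at that file
# (or at a URL); flattening expands it in place, preserving order.
specs = [
    {'type': 'TEST', 'name': 'INLINED'},
    {'include': providers_path},
]
print(list(yaml_provider.flatten_included_provider_specs(specs)))
# [{'type': 'TEST', 'name': 'INLINED'}, {'type': 'TEST', 'name': 'INCLUDED'}]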
42 changes: 32 additions & 10 deletions sdks/python/apache_beam/yaml/yaml_provider.py
@@ -34,6 +34,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Iterator
 from typing import Mapping
 from typing import Optional

@@ -45,6 +46,7 @@
 import apache_beam.dataframe.io
 import apache_beam.io
 import apache_beam.transforms.util
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.portability.api import schema_pb2
 from apache_beam.runners import pipeline_context
 from apache_beam.testing.util import assert_that
@@ -222,7 +224,10 @@ def provider_from_spec(cls, spec):
       config['version'] = beam_version
     if type in cls._provider_types:
       try:
-        return cls._provider_types[type](urns, **config)
+        result = cls._provider_types[type](urns, **config)
+        if not hasattr(result, 'to_json'):
+          result.to_json = lambda: spec
+        return result
       except Exception as exn:
         raise ValueError(
             f'Unable to instantiate provider of type {type} '
@@ -1153,18 +1158,35 @@ def cache_artifacts(self):
     self._underlying_provider.cache_artifacts()


-def parse_providers(provider_specs):
-  providers = collections.defaultdict(list)
+def flatten_included_provider_specs(
+    provider_specs: Iterable[Mapping]) -> Iterator[Mapping]:
   for provider_spec in provider_specs:
-    provider = ExternalProvider.provider_from_spec(provider_spec)
-    for transform_type in provider.provided_transforms():
-      providers[transform_type].append(provider)
-      # TODO: Do this better.
-      provider.to_json = lambda result=provider_spec: result
-  return providers
+    if 'include' in provider_spec:
+      if len(provider_spec) != 1:
+        raise ValueError(f"Invalid provider spec: {provider_spec}")
+      try:
+        with urllib.request.urlopen(provider_spec['include']) as response:
+          content = response.read()
+      except (ValueError, urllib.error.URLError) as exn:
+        if 'unknown url type' in str(exn):
+          with FileSystems.open(provider_spec['include']) as fin:
+            content = fin.read()
+        else:
+          raise
Review thread (on the error handling above):

Collaborator:
Shall we raise a runtime error here?

Contributor Author:
Is there an advantage to raising a RuntimeError rather than preserving the specific error that fetching the URL produced?

Collaborator:
Explicitly mention what kind of providers are supported? Or maybe we can update raise ValueError(f"Invalid provider spec: {provider_spec}") - basically making the error message more actionable.

Contributor Author:
Good call. I've added more explicit error messages.

Collaborator:
Thanks!

+      yield from flatten_included_provider_specs(
+          yaml.load(content, Loader=SafeLoader))
+    else:
+      yield provider_spec


+def parse_providers(provider_specs: Iterable[Mapping]) -> Iterable[Provider]:
+  return [
+      ExternalProvider.provider_from_spec(provider_spec)
+      for provider_spec in flatten_included_provider_specs(provider_specs)
+  ]
+
+
-def merge_providers(*provider_sets):
+def merge_providers(*provider_sets) -> Mapping[str, Iterable[Provider]]:
   result = collections.defaultdict(list)
   for provider_set in provider_sets:
     if isinstance(provider_set, Provider):
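
The review thread above resulted in more explicit error messages, added in the two commits that are not part of this three-commit view. Purely as an illustrative sketch of the kind of actionable message discussed (this is not the actual follow-up code; the helper name and wording are hypothetical):

def _validate_include_spec(provider_spec):
  # Hypothetical helper: spell out what an 'include' entry is allowed to
  # look like instead of just echoing the offending spec.
  if 'include' in provider_spec and len(provider_spec) != 1:
    raise ValueError(
        "A provider spec using 'include' may not contain any other keys; "
        "expected {'include': <path or URL>} but got: " + repr(provider_spec))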
73 changes: 73 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
@@ -16,8 +16,13 @@
 #

 import logging
+import os
+import tempfile
 import unittest

+import yaml
+
+from apache_beam.yaml import yaml_provider
 from apache_beam.yaml.yaml_provider import YamlProviders


@@ -63,6 +68,74 @@ def test_parse_duration_with_missing_value(self):
       self.parse_duration('s', 'size')


+class ProviderParsingTest(unittest.TestCase):
+
+  INLINE_PROVIDER = {'type': 'TEST', 'name': 'INLINED'}
+  INCLUDED_PROVIDER = {'type': 'TEST', 'name': 'INCLUDED'}
+  EXTRA_PROVIDER = {'type': 'TEST', 'name': 'EXTRA'}
+
+  @classmethod
+  def setUpClass(cls):
+    cls.tempdir = tempfile.TemporaryDirectory()
+    cls.to_include = os.path.join(cls.tempdir.name, 'providers.yaml')
+    with open(cls.to_include, 'w') as fout:
+      yaml.dump([cls.INCLUDED_PROVIDER], fout)
+    cls.to_include_nested = os.path.join(
+        cls.tempdir.name, 'nested_providers.yaml')
+    with open(cls.to_include_nested, 'w') as fout:
+      yaml.dump([{'include': cls.to_include}, cls.EXTRA_PROVIDER], fout)
+
+  @classmethod
+  def tearDownClass(cls):
+    cls.tempdir.cleanup()
+
+  def test_include_file(self):
+    flattened = list(
+        yaml_provider.flatten_included_provider_specs([
+            self.INLINE_PROVIDER,
+            {
+                'include': self.to_include
+            },
+        ]))
+
+    self.assertEqual([
+        self.INLINE_PROVIDER,
+        self.INCLUDED_PROVIDER,
+    ],
+                     flattened)
+
+  def test_include_url(self):
+    flattened = list(
+        yaml_provider.flatten_included_provider_specs([
+            self.INLINE_PROVIDER,
+            {
+                'include': 'file:///' + self.to_include
+            },
+        ]))
+
+    self.assertEqual([
+        self.INLINE_PROVIDER,
+        self.INCLUDED_PROVIDER,
+    ],
+                     flattened)
+
+  def test_nested_include(self):
+    flattened = list(
+        yaml_provider.flatten_included_provider_specs([
+            self.INLINE_PROVIDER,
+            {
+                'include': self.to_include_nested
+            },
+        ]))
+
+    self.assertEqual([
+        self.INLINE_PROVIDER,
+        self.INCLUDED_PROVIDER,
+        self.EXTRA_PROVIDER,
+    ],
+                     flattened)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
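
To exercise just the new test case locally, one option is unittest's programmatic API (a sketch; it assumes the Beam Python SDK containing this change is importable in the current environment):

import unittest

from apache_beam.yaml import yaml_provider_unit_test

# Load and run only the ProviderParsingTest case added in this PR.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(
    yaml_provider_unit_test.ProviderParsingTest)
unittest.TextTestRunner(verbosity=2).run(suite)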
10 changes: 3 additions & 7 deletions sdks/python/apache_beam/yaml/yaml_transform.py
@@ -1037,10 +1037,6 @@ def expand_pipeline(
   # Calling expand directly to avoid outer layer of nesting.
   return YamlTransform(
       pipeline_as_composite(pipeline_spec['pipeline']),
-      {
-          **yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
-          **{
-              key: yaml_provider.as_provider_list(key, value)
-              for (key, value) in (providers or {}).items()
-          }
-      }).expand(beam.pvalue.PBegin(pipeline))
+      yaml_provider.merge_providers(
+          pipeline_spec.get('providers', []), providers or
+          {})).expand(beam.pvalue.PBegin(pipeline))