Skip to content

Commit

Permalink
Merge pull request #31549 Allow flags to be used as jinja template va…
Browse files Browse the repository at this point in the history
…riables.
  • Loading branch information
robertwb committed Jun 10, 2024
2 parents d560aa6 + b4c83a6 commit ab94c8f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
44 changes: 42 additions & 2 deletions sdks/python/apache_beam/yaml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,46 @@
from apache_beam.yaml import yaml_transform


def _configure_parser(argv):
def _preparse_jinja_flags(argv):
"""Promotes any flags to --jinja_variables based on --jinja_variable_flags.
This is to facilitate tools (such as dataflow templates) that must pass
options as un-nested flags.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--jinja_variable_flags',
default=[],
type=lambda s: s.split(','),
help='A list of flag names that should be used as jinja variables.')
parser.add_argument(
'--jinja_variables',
default={},
type=json.loads,
help='A json dict of variables used when invoking the jinja preprocessor '
'on the provided yaml pipeline.')
jinja_args, other_args = parser.parse_known_args(argv)
if not jinja_args.jinja_variable_flags:
return argv

jinja_variable_parser = argparse.ArgumentParser()
for flag_name in jinja_args.jinja_variable_flags:
jinja_variable_parser.add_argument('--' + flag_name)
jinja_flag_variables, pipeline_args = jinja_variable_parser.parse_known_args(
other_args)
jinja_args.jinja_variables.update(
**
{k: v
for (k, v) in vars(jinja_flag_variables).items() if v is not None})
if jinja_args.jinja_variables:
pipeline_args = pipeline_args + [
'--jinja_variables=' + json.dumps(jinja_args.jinja_variables)
]

return pipeline_args


def _parse_arguments(argv):
parser = argparse.ArgumentParser()
parser.add_argument(
'--yaml_pipeline',
Expand Down Expand Up @@ -90,7 +129,8 @@ def _fix_xlang_instant_coding():


def run(argv=None):
known_args, pipeline_args = _configure_parser(argv)
argv = _preparse_jinja_flags(argv)
known_args, pipeline_args = _parse_arguments(argv)
pipeline_template = _pipeline_spec_from_args(known_args)
pipeline_yaml = ( # keep formatting
jinja2.Environment(
Expand Down
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/yaml/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ def test_jinja_variables(self):
with open(glob.glob(out_path + '*')[0], 'rt') as fin:
self.assertEqual(fin.read().strip(), 'my_line')

def test_jinja_variable_flags(self):
with tempfile.TemporaryDirectory() as tmpdir:
out_path = os.path.join(tmpdir, 'out.txt')
main.run([
'--yaml_pipeline',
TEST_PIPELINE.replace('PATH', out_path).replace('ELEMENT', '{{var}}'),
'--jinja_variable_flags=var',
'--var=my_line'
])
with open(glob.glob(out_path + '*')[0], 'rt') as fin:
self.assertEqual(fin.read().strip(), 'my_line')

def test_preparse_jinja_flags(self):
argv = [
'--jinja_variables={"from_vars": 1, "from_both": 2}',
'--jinja_variable_flags=from_both,from_flag,from_missing_flag',
'--from_both=30',
'--from_flag=40',
'--another_arg=foo',
'pos_arg',
]
self.assertCountEqual(
main._preparse_jinja_flags(argv),
[
'--jinja_variables='
'{"from_vars": 1, "from_both": "30", "from_flag": "40"}',
'--another_arg=foo',
'pos_arg',
])


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit ab94c8f

Please sign in to comment.