From a8205d789a8f9da975d051ec90c2bc6a3ca7b1b6 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 7 Jun 2024 11:35:55 -0700 Subject: [PATCH 1/2] Allow flags to be used as jinja template variables. --- sdks/python/apache_beam/yaml/main.py | 40 +++++++++++++++++++++++ sdks/python/apache_beam/yaml/main_test.py | 30 +++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py index 6c87a1ba7e68..30bf6f76017a 100644 --- a/sdks/python/apache_beam/yaml/main.py +++ b/sdks/python/apache_beam/yaml/main.py @@ -29,6 +29,45 @@ from apache_beam.yaml import yaml_transform +def _preparse_jinja_flags(argv): + """Promotes any flags to --jinja_variables based on --jinja_variable_flags. + + This is to facilitate tools (such as dataflow templates) that must pass + options as un-nested flags. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--jinja_variable_flags', + default=[], + type=lambda s: s.split(','), + help='A list of flag names that should be used as jinja variables.') + parser.add_argument( + '--jinja_variables', + default={}, + type=json.loads, + help='A json dict of variables used when invoking the jinja preprocessor ' + 'on the provided yaml pipeline.') + jinja_args, other_args = parser.parse_known_args(argv) + if not jinja_args.jinja_variable_flags: + return argv + + jinja_variable_parser = argparse.ArgumentParser() + for flag_name in jinja_args.jinja_variable_flags: + jinja_variable_parser.add_argument('--' + flag_name) + jinja_flag_variables, pipeline_args = jinja_variable_parser.parse_known_args( + other_args) + jinja_args.jinja_variables.update( + ** + {k: v + for (k, v) in vars(jinja_flag_variables).items() if v is not None}) + if jinja_args.jinja_variables: + pipeline_args = pipeline_args + [ + '--jinja_variables=' + json.dumps(jinja_args.jinja_variables) + ] + + return pipeline_args + + def _configure_parser(argv): parser = argparse.ArgumentParser() parser.add_argument( @@ -90,6 +129,7 @@ def _fix_xlang_instant_coding(): def run(argv=None): + argv = _preparse_jinja_flags(argv) known_args, pipeline_args = _configure_parser(argv) pipeline_template = _pipeline_spec_from_args(known_args) pipeline_yaml = ( # keep formatting diff --git a/sdks/python/apache_beam/yaml/main_test.py b/sdks/python/apache_beam/yaml/main_test.py index b10c788bccaa..1a3da6443b72 100644 --- a/sdks/python/apache_beam/yaml/main_test.py +++ b/sdks/python/apache_beam/yaml/main_test.py @@ -70,6 +70,36 @@ def test_jinja_variables(self): with open(glob.glob(out_path + '*')[0], 'rt') as fin: self.assertEqual(fin.read().strip(), 'my_line') + def test_jinja_variable_flags(self): + with tempfile.TemporaryDirectory() as tmpdir: + out_path = os.path.join(tmpdir, 'out.txt') + main.run([ + '--yaml_pipeline', + TEST_PIPELINE.replace('PATH', out_path).replace('ELEMENT', '{{var}}'), + '--jinja_variable_flags=var', + '--var=my_line' + ]) + with open(glob.glob(out_path + '*')[0], 'rt') as fin: + self.assertEqual(fin.read().strip(), 'my_line') + + def test_preparse_jinja_flags(self): + argv = [ + '--jinja_variables={"from_vars": 1, "from_both": 2}', + '--jinja_variable_flags=from_both,from_flag,from_missing_flag', + '--from_both=30', + '--from_flag=40', + '--another_arg=foo', + 'pos_arg', + ] + self.assertCountEqual( + main._preparse_jinja_flags(argv), + [ + '--jinja_variables=' + '{"from_vars": 1, "from_both": "30", "from_flag": "40"}', + '--another_arg=foo', + 'pos_arg', + ]) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From b4c83a69cefa0dc194571c9f94651cb7165784da Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Mon, 10 Jun 2024 10:09:34 -0700 Subject: [PATCH 2/2] Less confusing name for argument parser. --- sdks/python/apache_beam/yaml/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py index 30bf6f76017a..781f1b0ba95a 100644 --- a/sdks/python/apache_beam/yaml/main.py +++ b/sdks/python/apache_beam/yaml/main.py @@ -68,7 +68,7 @@ def _preparse_jinja_flags(argv): return pipeline_args -def _configure_parser(argv): +def _parse_arguments(argv): parser = argparse.ArgumentParser() parser.add_argument( '--yaml_pipeline', @@ -130,7 +130,7 @@ def _fix_xlang_instant_coding(): def run(argv=None): argv = _preparse_jinja_flags(argv) - known_args, pipeline_args = _configure_parser(argv) + known_args, pipeline_args = _parse_arguments(argv) pipeline_template = _pipeline_spec_from_args(known_args) pipeline_yaml = ( # keep formatting jinja2.Environment(