diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py
index 6c87a1ba7e68a..781f1b0ba95a4 100644
--- a/sdks/python/apache_beam/yaml/main.py
+++ b/sdks/python/apache_beam/yaml/main.py
@@ -29,7 +29,46 @@
 from apache_beam.yaml import yaml_transform
 
 
-def _configure_parser(argv):
+def _preparse_jinja_flags(argv):
+  """Promotes any flags to --jinja_variables based on --jinja_variable_flags.
+
+  This is to facilitate tools (such as dataflow templates) that must pass
+  options as un-nested flags.
+  """
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--jinja_variable_flags',
+      default=[],
+      type=lambda s: s.split(','),
+      help='A list of flag names that should be used as jinja variables.')
+  parser.add_argument(
+      '--jinja_variables',
+      default={},
+      type=json.loads,
+      help='A json dict of variables used when invoking the jinja preprocessor '
+      'on the provided yaml pipeline.')
+  jinja_args, other_args = parser.parse_known_args(argv)
+  if not jinja_args.jinja_variable_flags:
+    return argv
+
+  jinja_variable_parser = argparse.ArgumentParser()
+  for flag_name in jinja_args.jinja_variable_flags:
+    jinja_variable_parser.add_argument('--' + flag_name)
+  jinja_flag_variables, pipeline_args = jinja_variable_parser.parse_known_args(
+      other_args)
+  jinja_args.jinja_variables.update(
+      **
+      {k: v
+       for (k, v) in vars(jinja_flag_variables).items() if v is not None})
+  if jinja_args.jinja_variables:
+    pipeline_args = pipeline_args + [
+        '--jinja_variables=' + json.dumps(jinja_args.jinja_variables)
+    ]
+
+  return pipeline_args
+
+
+def _parse_arguments(argv):
   parser = argparse.ArgumentParser()
   parser.add_argument(
       '--yaml_pipeline',
@@ -90,7 +129,8 @@ def _fix_xlang_instant_coding():
 
 
 def run(argv=None):
-  known_args, pipeline_args = _configure_parser(argv)
+  argv = _preparse_jinja_flags(argv)
+  known_args, pipeline_args = _parse_arguments(argv)
   pipeline_template = _pipeline_spec_from_args(known_args)
   pipeline_yaml = (  # keep formatting
       jinja2.Environment(
diff --git a/sdks/python/apache_beam/yaml/main_test.py b/sdks/python/apache_beam/yaml/main_test.py
index b10c788bccaa1..1a3da6443b722 100644
--- a/sdks/python/apache_beam/yaml/main_test.py
+++ b/sdks/python/apache_beam/yaml/main_test.py
@@ -70,6 +70,36 @@ def test_jinja_variables(self):
       with open(glob.glob(out_path + '*')[0], 'rt') as fin:
         self.assertEqual(fin.read().strip(), 'my_line')
 
+  def test_jinja_variable_flags(self):
+    with tempfile.TemporaryDirectory() as tmpdir:
+      out_path = os.path.join(tmpdir, 'out.txt')
+      main.run([
+          '--yaml_pipeline',
+          TEST_PIPELINE.replace('PATH', out_path).replace('ELEMENT', '{{var}}'),
+          '--jinja_variable_flags=var',
+          '--var=my_line'
+      ])
+      with open(glob.glob(out_path + '*')[0], 'rt') as fin:
+        self.assertEqual(fin.read().strip(), 'my_line')
+
+  def test_preparse_jinja_flags(self):
+    argv = [
+        '--jinja_variables={"from_vars": 1, "from_both": 2}',
+        '--jinja_variable_flags=from_both,from_flag,from_missing_flag',
+        '--from_both=30',
+        '--from_flag=40',
+        '--another_arg=foo',
+        'pos_arg',
+    ]
+    self.assertCountEqual(
+        main._preparse_jinja_flags(argv),
+        [
+            '--jinja_variables='
+            '{"from_vars": 1, "from_both": "30", "from_flag": "40"}',
+            '--another_arg=foo',
+            'pos_arg',
+        ])
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)