From f0079293456835acb68721e54d88c7de1c570a11 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 14 Jan 2025 13:28:17 +0100 Subject: [PATCH] Update `dags reserialize` command to work with DAG bundle (#45507) * Update `dags reserialize` command to work with DAG bundle This PR also changed the `--subdir` arg to --bundle-name. * fixup! Update `dags reserialize` command to work with DAG bundle * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> * Update tests/cli/commands/remote_commands/test_dag_command.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> * fixup! Update tests/cli/commands/remote_commands/test_dag_command.py --------- Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/cli/cli_config.py | 10 ++++- .../commands/remote_commands/dag_command.py | 20 +++++++-- .../remote_commands/test_dag_command.py | 44 +++++++++++++------ 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index cb847d30db03c..56f0e1c3b4937 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -166,6 +166,14 @@ def string_lower_type(val): ), default="[AIRFLOW_HOME]/dags" if BUILD_DOCS else settings.DAGS_FOLDER, ) +ARG_BUNDLE_NAME = Arg( + ( + "-B", + "--bundle-name", + ), + help=("The name of the DAG bundle to use."), + default=None, +) ARG_START_DATE = Arg(("-s", "--start-date"), help="Override start_date YYYY-MM-DD", type=parsedate) ARG_END_DATE = Arg(("-e", "--end-date"), help="Override end_date YYYY-MM-DD", type=parsedate) ARG_OUTPUT_PATH = Arg( @@ -1225,7 +1233,7 @@ class GroupCommand(NamedTuple): ), func=lazy_load_command("airflow.cli.commands.remote_commands.dag_command.dag_reserialize"), args=( - ARG_SUBDIR, + ARG_BUNDLE_NAME, ARG_VERBOSE, ), ), diff --git a/airflow/cli/commands/remote_commands/dag_command.py b/airflow/cli/commands/remote_commands/dag_command.py index a405e8fba2214..1e922029abff2 100644 --- a/airflow/cli/commands/remote_commands/dag_command.py +++ b/airflow/cli/commands/remote_commands/dag_command.py @@ -537,7 +537,19 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No @provide_session def dag_reserialize(args, session: Session = NEW_SESSION) -> None: """Serialize a DAG instance.""" - # TODO: AIP-66 bundle centric reserialize - raise NotImplementedError( - "AIP-66: This command is not implemented yet - use `dag-processor --num-runs 1` in the meantime." - ) + from airflow.dag_processing.bundles.manager import DagBundlesManager + + manager = DagBundlesManager() + manager.sync_bundles_to_db(session=session) + session.commit() + if args.bundle_name: + bundle = manager.get_bundle(args.bundle_name) + if not bundle: + raise SystemExit(f"Bundle {args.bundle_name} not found") + dag_bag = DagBag(bundle.path, include_examples=False) + dag_bag.sync_to_db(bundle.name, bundle_version=bundle.get_current_version(), session=session) + else: + bundles = manager.get_all_dag_bundles() + for bundle in bundles: + dag_bag = DagBag(bundle.path, include_examples=False) + dag_bag.sync_to_db(bundle.name, bundle_version=bundle.get_current_version(), session=session) diff --git a/tests/cli/commands/remote_commands/test_dag_command.py b/tests/cli/commands/remote_commands/test_dag_command.py index 0db0ac83df02f..04b0ef88104e7 100644 --- a/tests/cli/commands/remote_commands/test_dag_command.py +++ b/tests/cli/commands/remote_commands/test_dag_command.py @@ -79,7 +79,6 @@ def teardown_class(cls) -> None: def setup_method(self): clear_db_runs() # clean-up all dag run before start each test - @pytest.mark.skip("AIP-66: reserialize is not implemented yet") def test_reserialize(self, session): # Assert that there are serialized Dags serialized_dags_before_command = session.query(SerializedDagModel).all() @@ -99,8 +98,7 @@ def test_reserialize(self, session): dag_version_after_command = session.query(DagVersion).all() assert len(dag_version_after_command) - @pytest.mark.skip("AIP-66: reserialize is not implemented yet") - def test_reserialize_should_support_subdir_argument(self, session): + def test_reserialize_should_support_bundle_name_argument(self, configure_testing_dag_bundle, session): # Run clear of serialized dags session.query(DagVersion).delete() @@ -108,20 +106,38 @@ def test_reserialize_should_support_subdir_argument(self, session): serialized_dags_after_clear = session.query(SerializedDagModel).all() assert len(serialized_dags_after_clear) == 0 - # Serialize manually - dag_path = self.dagbag.dags["example_bash_operator"].fileloc - # Set default value of include_examples parameter to false - dagbag_default = list(DagBag.__init__.__defaults__) - dagbag_default[1] = False - with mock.patch( - "airflow.cli.commands.remote_commands.dag_command.DagBag.__init__.__defaults__", - tuple(dagbag_default), - ): - dag_command.dag_reserialize(self.parser.parse_args(["dags", "reserialize", "--subdir", dag_path])) + path_to_parse = TEST_DAGS_FOLDER / "test_dag_with_no_tags.py" + + with configure_testing_dag_bundle(path_to_parse): + # reserializes only the above path + dag_command.dag_reserialize( + self.parser.parse_args(["dags", "reserialize", "--bundle-name", "testing"]) + ) + + # Check serialized DAG are back + serialized_dags_after_reserialize = session.query(SerializedDagModel).all() + assert len(serialized_dags_after_reserialize) == 1 + + def test_reserialize_should_support_more_than_one_bundle(self, configure_testing_dag_bundle, session): + # Run clear of serialized dags + session.query(DagVersion).delete() + + # Assert no serialized Dags + serialized_dags_after_clear = session.query(SerializedDagModel).all() + assert len(serialized_dags_after_clear) == 0 + + path_to_parse = TEST_DAGS_FOLDER / "test_dag_with_no_tags.py" + + with configure_testing_dag_bundle(path_to_parse): + # The command will now serialize the above bundle and the example dag bundle + dag_command.dag_reserialize(self.parser.parse_args(["dags", "reserialize"])) # Check serialized DAG are back serialized_dags_after_reserialize = session.query(SerializedDagModel).all() - assert len(serialized_dags_after_reserialize) == 1 # Serialized DAG back + assert len(serialized_dags_after_reserialize) > 1 + serialized_dag_ids = [dag.dag_id for dag in serialized_dags_after_reserialize] + assert "test_dag_with_no_tags" in serialized_dag_ids + assert "example_bash_operator" in serialized_dag_ids def test_show_dag_dependencies_print(self): with contextlib.redirect_stdout(StringIO()) as temp_stdout: