Skip to content

Commit

Permalink
Update dags reserialize command to work with DAG bundle (apache#45507)
Browse files Browse the repository at this point in the history
* Update `dags reserialize` command to work with DAG bundle

This PR also changed the `--subdir` arg to --bundle-name.

* fixup! Update `dags reserialize` command to work with DAG bundle

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>

* Update tests/cli/commands/remote_commands/test_dag_command.py

Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>

* fixup! Update tests/cli/commands/remote_commands/test_dag_command.py

---------

Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
  • Loading branch information
ephraimbuddy and jedcunningham authored Jan 14, 2025
1 parent d200f9a commit f007929
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 19 deletions.
10 changes: 9 additions & 1 deletion airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ def string_lower_type(val):
),
default="[AIRFLOW_HOME]/dags" if BUILD_DOCS else settings.DAGS_FOLDER,
)
ARG_BUNDLE_NAME = Arg(
(
"-B",
"--bundle-name",
),
help=("The name of the DAG bundle to use."),
default=None,
)
ARG_START_DATE = Arg(("-s", "--start-date"), help="Override start_date YYYY-MM-DD", type=parsedate)
ARG_END_DATE = Arg(("-e", "--end-date"), help="Override end_date YYYY-MM-DD", type=parsedate)
ARG_OUTPUT_PATH = Arg(
Expand Down Expand Up @@ -1225,7 +1233,7 @@ class GroupCommand(NamedTuple):
),
func=lazy_load_command("airflow.cli.commands.remote_commands.dag_command.dag_reserialize"),
args=(
ARG_SUBDIR,
ARG_BUNDLE_NAME,
ARG_VERBOSE,
),
),
Expand Down
20 changes: 16 additions & 4 deletions airflow/cli/commands/remote_commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,19 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
@provide_session
def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
"""Serialize a DAG instance."""
# TODO: AIP-66 bundle centric reserialize
raise NotImplementedError(
"AIP-66: This command is not implemented yet - use `dag-processor --num-runs 1` in the meantime."
)
from airflow.dag_processing.bundles.manager import DagBundlesManager

manager = DagBundlesManager()
manager.sync_bundles_to_db(session=session)
session.commit()
if args.bundle_name:
bundle = manager.get_bundle(args.bundle_name)
if not bundle:
raise SystemExit(f"Bundle {args.bundle_name} not found")
dag_bag = DagBag(bundle.path, include_examples=False)
dag_bag.sync_to_db(bundle.name, bundle_version=bundle.get_current_version(), session=session)
else:
bundles = manager.get_all_dag_bundles()
for bundle in bundles:
dag_bag = DagBag(bundle.path, include_examples=False)
dag_bag.sync_to_db(bundle.name, bundle_version=bundle.get_current_version(), session=session)
44 changes: 30 additions & 14 deletions tests/cli/commands/remote_commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def teardown_class(cls) -> None:
def setup_method(self):
clear_db_runs() # clean-up all dag run before start each test

@pytest.mark.skip("AIP-66: reserialize is not implemented yet")
def test_reserialize(self, session):
# Assert that there are serialized Dags
serialized_dags_before_command = session.query(SerializedDagModel).all()
Expand All @@ -99,29 +98,46 @@ def test_reserialize(self, session):
dag_version_after_command = session.query(DagVersion).all()
assert len(dag_version_after_command)

@pytest.mark.skip("AIP-66: reserialize is not implemented yet")
def test_reserialize_should_support_subdir_argument(self, session):
def test_reserialize_should_support_bundle_name_argument(self, configure_testing_dag_bundle, session):
# Run clear of serialized dags
session.query(DagVersion).delete()

# Assert no serialized Dags
serialized_dags_after_clear = session.query(SerializedDagModel).all()
assert len(serialized_dags_after_clear) == 0

# Serialize manually
dag_path = self.dagbag.dags["example_bash_operator"].fileloc
# Set default value of include_examples parameter to false
dagbag_default = list(DagBag.__init__.__defaults__)
dagbag_default[1] = False
with mock.patch(
"airflow.cli.commands.remote_commands.dag_command.DagBag.__init__.__defaults__",
tuple(dagbag_default),
):
dag_command.dag_reserialize(self.parser.parse_args(["dags", "reserialize", "--subdir", dag_path]))
path_to_parse = TEST_DAGS_FOLDER / "test_dag_with_no_tags.py"

with configure_testing_dag_bundle(path_to_parse):
# reserializes only the above path
dag_command.dag_reserialize(
self.parser.parse_args(["dags", "reserialize", "--bundle-name", "testing"])
)

# Check serialized DAG are back
serialized_dags_after_reserialize = session.query(SerializedDagModel).all()
assert len(serialized_dags_after_reserialize) == 1

def test_reserialize_should_support_more_than_one_bundle(self, configure_testing_dag_bundle, session):
# Run clear of serialized dags
session.query(DagVersion).delete()

# Assert no serialized Dags
serialized_dags_after_clear = session.query(SerializedDagModel).all()
assert len(serialized_dags_after_clear) == 0

path_to_parse = TEST_DAGS_FOLDER / "test_dag_with_no_tags.py"

with configure_testing_dag_bundle(path_to_parse):
# The command will now serialize the above bundle and the example dag bundle
dag_command.dag_reserialize(self.parser.parse_args(["dags", "reserialize"]))

# Check serialized DAG are back
serialized_dags_after_reserialize = session.query(SerializedDagModel).all()
assert len(serialized_dags_after_reserialize) == 1 # Serialized DAG back
assert len(serialized_dags_after_reserialize) > 1
serialized_dag_ids = [dag.dag_id for dag in serialized_dags_after_reserialize]
assert "test_dag_with_no_tags" in serialized_dag_ids
assert "example_bash_operator" in serialized_dag_ids

def test_show_dag_dependencies_print(self):
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
Expand Down

0 comments on commit f007929

Please sign in to comment.