diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 988dfb25e4fa4..d2366c0e9e7a0 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -41,6 +41,7 @@ Callable, Collection, Container, + Generator, Iterable, Iterator, List, @@ -2627,15 +2628,25 @@ def pickle(self, session=NEW_SESSION) -> DagPickle: def tree_view(self) -> None: """Print an ASCII tree representation of the DAG.""" + for tmp in self._generate_tree_view(): + print(tmp) - def get_downstream(task, level=0): - print((" " * level * 4) + str(task)) + def _generate_tree_view(self) -> Generator[str, None, None]: + def get_downstream(task, level=0) -> Generator[str, None, None]: + yield (" " * level * 4) + str(task) level += 1 - for t in task.downstream_list: - get_downstream(t, level) - - for t in self.roots: - get_downstream(t) + for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id): + yield from get_downstream(tmp_task, level) + + for t in sorted(self.roots, key=lambda x: x.task_id): + yield from get_downstream(t) + + def get_tree_view(self) -> str: + """Return an ASCII tree representation of the DAG.""" + rst = "" + for tmp in self._generate_tree_view(): + rst += tmp + "\n" + return rst @property def task(self) -> TaskDecoratorCollection: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 1f70ba051af09..05681cfe8855d 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1422,19 +1422,30 @@ def test_leaves(self): def test_tree_view(self): """Verify correctness of dag.tree_view().""" with DAG("test_dag", start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") + op1_a = EmptyOperator(task_id="t1_a") + op1_b = EmptyOperator(task_id="t1_b") op2 = EmptyOperator(task_id="t2") op3 = EmptyOperator(task_id="t3") - op1 >> op2 >> op3 + op1_b >> op2 + op1_a >> op2 >> op3 with redirect_stdout(StringIO()) as stdout: dag.tree_view() stdout = stdout.getvalue() stdout_lines = stdout.splitlines() - assert "t1" in stdout_lines[0] + assert "t1_a" in stdout_lines[0] assert "t2" in stdout_lines[1] assert "t3" in stdout_lines[2] + assert "t1_b" in stdout_lines[3] + assert dag.get_tree_view() == ( + "\n" + " \n" + " \n" + "\n" + " \n" + " \n" + ) def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): """Verify tasks with Duplicate task_id raises error"""