diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index f22ea078fb87eb..4759356ae124a4 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -1,6 +1,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator +from typing import Optional from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent @@ -48,25 +49,35 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExcept # we remove the node maybe shortcut the answer node, so comment this code for now # there is not effect on the answer node and the workflow, when we have a better solution # we can open this code. Issues: #11542 #9560 #10638 #10564 - ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) - if "answer" in ids: - continue - else: - reachable_node_ids.extend(ids) + # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) + # if "answer" in ids: + # continue + # else: + # reachable_node_ids.extend(ids) + + # The branch_identify parameter is added to ensure that + # only nodes in the correct logical branch are included. + ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) + reachable_node_ids.extend(ids) else: unreachable_first_node_ids.append(edge.target_node_id) for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) - def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: + def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: node_ids = [] for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id == self.graph.root_node_id: continue + # Only follow edges that match the branch_identify or have no run_condition + if edge.run_condition and edge.run_condition.branch_identify: + if not branch_identify or edge.run_condition.branch_identify != branch_identify: + continue + node_ids.append(edge.target_node_id) - node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) return node_ids def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: