Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Isolate logic that distributes Components output after run #7845

Merged
merged 13 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,122 @@ def _init_graph(self):
for node in self.graph.nodes:
self.graph.nodes[node]["visits"] = 0

def _dequeue_components_that_received_no_input(
self,
component_name: str,
component_result: Dict[str, Any],
to_run: List[Tuple[str, Component]],
waiting_for_input: List[Tuple[str, Component]],
):
"""
Removes Components that didn't receive any input from the list of Components to run.

We can't run those Components if they didn't receive any input, even if it's optional.
This is mainly useful for Components that have conditional outputs.

:param component_name: Name of the Component that created the output
:param component_result: The output of the Component
:param to_run: Queue of Components to run
:param waiting_for_input: Queue of Components waiting for input
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
"""
instance: Component = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_output__._sockets_dict.items(): # type: ignore
if socket_name in component_result:
continue
for receiver in socket.receivers:
receiver_instance: Component = self.graph.nodes[receiver]["instance"]
pair = (receiver, receiver_instance)
if pair in to_run:
to_run.remove(pair)
if pair in waiting_for_input:
waiting_for_input.remove(pair)

def _distribute_output(
self,
component_name: str,
component_result: Dict[str, Any],
inputs_by_component: Dict[str, Dict[str, Any]],
to_run: List[Tuple[str, Component]],
waiting_for_input: List[Tuple[str, Component]],
) -> Dict[str, Any]:
"""
Distributes the output of a Component to the next Components that need it.

This also updates the queues that keep track of which Components are ready to run and which are waiting for input.

:param component_name: Name of the Component that created the output
:param component_result: The output of the Component
:paramt inputs_by_component: The current state of the inputs divided by Component name
:param to_run: Queue of Components to run
:param waiting_for_input: Queue of Components waiting for input

:return: The updated output of the Component without the keys that were distributed to other Components
"""
# We keep track of which keys to remove from component_result at the end of the loop.
# This is done after the output has been distributed to the next components, so that
# we're sure all components that need this output have received it.
to_remove_from_component_result = set()

for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
sender_socket: OutputSocket = connection["from_socket"]
receiver_socket: InputSocket = connection["to_socket"]

if sender_socket.name not in component_result:
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
# This output wasn't created by the sender, nothing we can do.
#
# Some Components might have conditional outputs, so we need to check if they actually returned
# some output while iterating over their output sockets.
#
# A perfect example of this would be the ConditionalRouter, which will have an output for each
# condition it has been initialized with.
# Though it will return only one output at a time.
continue

if receiver_name not in inputs_by_component:
inputs_by_component[receiver_name] = {}

# We keep track of the keys that were distributed to other Components.
# This key will be removed from component_result at the end of the loop.
to_remove_from_component_result.add(sender_socket.name)

value = component_result[sender_socket.name]

if receiver_socket.is_variadic:
# Usually Component inputs can only be received from one sender, the Variadic type allows
# instead to receive inputs from multiple senders.
#
# To keep track of all the inputs received internally we always store them in a list.
if receiver_socket.name not in inputs_by_component[receiver_name]:
# Create the list if it doesn't exist
inputs_by_component[receiver_name][receiver_socket.name] = []
else:
# Check if the value is actually a list
assert isinstance(inputs_by_component[receiver_name][receiver_socket.name], list)
inputs_by_component[receiver_name][receiver_socket.name].append(value)
else:
inputs_by_component[receiver_name][receiver_socket.name] = value

receiver = self.graph.nodes[receiver_name]["instance"]
pair = (receiver_name, receiver)

is_greedy = getattr(receiver, "__haystack_is_greedy__", False)
if receiver_socket.is_variadic and is_greedy:
# If the receiver is greedy, we can run it as soon as possible.
# First we remove it from the status lists it's in if it's there or we risk running it multiple times.
if pair in to_run:
to_run.remove(pair)
if pair in waiting_for_input:
waiting_for_input.remove(pair)
to_run.append(pair)
shadeMe marked this conversation as resolved.
Show resolved Hide resolved

if pair not in waiting_for_input and pair not in to_run:
# Queue up the Component that received this input to run, only if it's not already waiting
# for input or already ready to run.
to_run.append(pair)

# Returns the output without the keys that were distributed to other Components
return {k: v for k, v in component_result.items() if k not in to_remove_from_component_result}


def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
Expand Down
60 changes: 9 additions & 51 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
res: Dict[str, Any] = instance.run(**inputs)
self.graph.nodes[name]["visits"] += 1

# After a Component that has variadic inputs is run, we need to reset the variadic inputs that were consumed
for socket in instance.__haystack_input__._sockets_dict.values(): # type: ignore
if socket.name not in inputs:
continue
if socket.is_variadic:
inputs[socket.name] = []
shadeMe marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(res, Mapping):
raise PipelineRuntimeError(
f"Component '{name}' didn't return a dictionary. "
Expand Down Expand Up @@ -253,57 +260,8 @@ def run(self, word: str):
# This happens when a component was put in the waiting list but we reached it from another edge.
waiting_for_input.remove((name, comp))

# We keep track of which keys to remove from res at the end of the loop.
# This is done after the output has been distributed to the next components, so that
# we're sure all components that need this output have received it.
to_remove_from_res = set()
for sender_component_name, receiver_component_name, edge_data in self.graph.edges(data=True):
if receiver_component_name == name and edge_data["to_socket"].is_variadic:
# Delete variadic inputs that were already consumed
last_inputs[name][edge_data["to_socket"].name] = []
shadeMe marked this conversation as resolved.
Show resolved Hide resolved

if name != sender_component_name:
continue

pair = (receiver_component_name, self.graph.nodes[receiver_component_name]["instance"])
if edge_data["from_socket"].name not in res:
# The component didn't produce any output for this socket.
# We can't run the receiver, let's remove it from the list of components to run
# or we risk running it if it's in those lists.
if pair in to_run:
to_run.remove(pair)
if pair in waiting_for_input:
waiting_for_input.remove(pair)
continue

if receiver_component_name not in last_inputs:
last_inputs[receiver_component_name] = {}
to_remove_from_res.add(edge_data["from_socket"].name)
value = res[edge_data["from_socket"].name]

if edge_data["to_socket"].is_variadic:
if edge_data["to_socket"].name not in last_inputs[receiver_component_name]:
last_inputs[receiver_component_name][edge_data["to_socket"].name] = []
# Add to the list of variadic inputs
last_inputs[receiver_component_name][edge_data["to_socket"].name].append(value)
else:
last_inputs[receiver_component_name][edge_data["to_socket"].name] = value

is_greedy = pair[1].__haystack_is_greedy__
is_variadic = edge_data["to_socket"].is_variadic
if is_variadic and is_greedy:
# If the receiver is greedy, we can run it right away.
# First we remove it from the lists it's in if it's there or we risk running it multiple times.
if pair in to_run:
to_run.remove(pair)
if pair in waiting_for_input:
waiting_for_input.remove(pair)
to_run.append(pair)

if pair not in waiting_for_input and pair not in to_run:
to_run.append(pair)

res = {k: v for k, v in res.items() if k not in to_remove_from_res}
self._dequeue_components_that_received_no_input(name, res, to_run, waiting_for_input)
res = self._distribute_output(name, res, last_inputs, to_run, waiting_for_input)

if len(res) > 0:
final_outputs[name] = res
Expand Down
64 changes: 64 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,15 @@ def test__run_component(self, spying_tracer, caplog):

assert caplog.messages == ["Running component document_builder"]

def test__run_component_with_variadic_input(self):
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()

pipe = Pipeline()
pipe.add_component("document_joiner", document_joiner)
inputs = {"docs": [Document(content="doc1"), Document(content="doc2")]}
pipe._run_component("document_joiner", inputs)
assert inputs == {"docs": []}

def test__component_has_enough_inputs_to_run(self):
sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]})()
pipe = Pipeline()
Expand All @@ -1055,3 +1064,58 @@ def test__component_has_enough_inputs_to_run(self):
assert pipe._component_has_enough_inputs_to_run(
"sentence_builder", {"sentence_builder": {"words": ["blah blah"]}}
)

def test__dequeue_components_that_received_no_input(self):
sentence_builder = component_class(
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"}
)()
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
)()

pipe = Pipeline()
pipe.add_component("sentence_builder", sentence_builder)
pipe.add_component("document_builder", document_builder)
pipe.connect("sentence_builder.text", "document_builder.text")

to_run = [("document_builder", document_builder)]
waiting_for_input = [("document_builder", document_builder)]
pipe._dequeue_components_that_received_no_input("sentence_builder", {}, to_run, waiting_for_input)
assert to_run == []
assert waiting_for_input == []

def test__distribute_output(self):
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document}
)()
document_cleaner = component_class(
"DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document}
)()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()

pipe = Pipeline()
pipe.add_component("document_builder", document_builder)
pipe.add_component("document_cleaner", document_cleaner)
pipe.add_component("document_joiner", document_joiner)
pipe.connect("document_builder.doc", "document_cleaner.doc")
pipe.connect("document_builder.another_doc", "document_joiner.docs")

inputs = {"document_builder": {"text": "some text"}}
to_run = []
waiting_for_input = [("document_joiner", document_joiner)]
res = pipe._distribute_output(
"document_builder",
{"doc": Document("some text"), "another_doc": Document()},
inputs,
to_run,
waiting_for_input,
)

assert res == {}
assert inputs == {
"document_builder": {"text": "some text"},
"document_cleaner": {"doc": Document("some text")},
"document_joiner": {"docs": [Document()]},
}
assert to_run == [("document_cleaner", document_cleaner)]
assert waiting_for_input == [("document_joiner", document_joiner)]