SequentialChain runs the same callbacks over and over in async mode #9401

Closed
4 of 14 tasks
vamseeyarla opened this issue Aug 17, 2023 · 2 comments
Labels
🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature Ɑ: models Related to LLMs or chat model modules

Comments

@vamseeyarla
Contributor

vamseeyarla commented Aug 17, 2023

System Info

In async mode, the SequentialChain implementation appears to run the same callbacks over and over because it reuses a single callbacks object for every step.

Langchain version: 0.0.264

The implementation of the async route differs from the sync route. The sync approach follows the right pattern: it generates a new callbacks object for each step instead of reusing the old one, which avoids the cascading run of callbacks at each step.

Async code

        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        ...
        for i, chain in enumerate(self.chains):
            _input = await chain.arun(_input, callbacks=callbacks)
            ...

Sync code

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        for i, chain in enumerate(self.chains):
            _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
            ...

Notice that the async code creates the callbacks object once, before the loop, and reuses it for every chain. This has a cascading effect as we run through the chain: the same callbacks run again at each later step.
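The cascading effect described above can be sketched with a toy model. This is not LangChain code: `CountingHandler`, `ToyManager`, `ToyStep`, and `run_with_shared_child` are hypothetical names, and the sketch assumes that each step registers its own handler on whatever manager it is handed, so a manager shared across steps keeps accumulating handlers.

```python
import asyncio
from typing import List


class CountingHandler:
    """Records how many step starts it was notified about."""

    def __init__(self) -> None:
        self.starts = 0

    def on_start(self) -> None:
        self.starts += 1


class ToyManager:
    """Hypothetical stand-in for a callback manager (not LangChain's class)."""

    def __init__(self) -> None:
        self.handlers: List[CountingHandler] = []

    def get_child(self) -> "ToyManager":
        return ToyManager()


class ToyStep:
    """Hypothetical chain step that owns exactly one handler."""

    def __init__(self, handler: CountingHandler) -> None:
        self.handler = handler

    async def arun(self, value: str, callbacks: ToyManager) -> str:
        # Assumption: the step registers its handler on the manager it is given.
        callbacks.handlers.append(self.handler)
        for handler in callbacks.handlers:
            handler.on_start()
        return value + "foo"


async def run_with_shared_child(steps: List[ToyStep], value: str) -> str:
    callbacks = ToyManager().get_child()  # created once, before the loop
    for step in steps:
        value = await step.arun(value, callbacks=callbacks)
    return value


handlers = [CountingHandler() for _ in range(3)]
output = asyncio.run(run_with_shared_child([ToyStep(h) for h in handlers], "123"))
start_counts = [h.starts for h in handlers]
```

Under these assumptions `start_counts` comes out as `[3, 2, 1]` rather than `[1, 1, 1]`: the first step's handler is still attached when the second and third steps run.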

CC @agola11

Who can help?

@agola11

Information

  • The official example notebooks/scripts
  • My own modified scripts

Related Components

  • LLMs/Chat Models
  • Embedding Models
  • Prompts / Prompt Templates / Prompt Selectors
  • Output Parsers
  • Document Loaders
  • Vector Stores / Retrievers
  • Memory
  • Agents / Agent Executors
  • Tools / Toolkits
  • Chains
  • Callbacks/Tracing
  • Async

Reproduction

To reproduce, write a simple sequential chain with three tasks, attach callbacks to each task, and run it in async mode. The log shows the same callbacks running over and over.

Expected behavior

Each callback should run once per task.

@dosubot dosubot bot added Ɑ: models Related to LLMs or chat model modules 🤖:bug Related to a bug, vulnerability, unexpected error with an existing feature labels Aug 17, 2023
@dosubot

dosubot bot commented Aug 17, 2023

🤖

Thank you for bringing this to our attention. You're correct in your observation that the SequentialChain implementation in async mode is reusing the same callbacks object for each chain, which is causing the callbacks to run repeatedly. This is indeed different from the sync mode, where a new callbacks object is created for each chain, thus avoiding the cascading run of callbacks at each step.

The issue lies in the placement of the line callbacks = _run_manager.get_child() in the _acall method. In the synchronous _call method, a new child callback manager is created for each chain in the loop. However, in the asynchronous _acall method, the child callback manager is created only once before the loop. This means that the same callbacks object is used for each chain in the async mode.

To fix this issue, you should move the line callbacks = _run_manager.get_child() inside the loop, just like in the synchronous _call method:

for i, chain in enumerate(self.chains):
    callbacks = _run_manager.get_child()
    outputs = await chain.acall(
        known_values, return_only_outputs=True, callbacks=callbacks
    )
    known_values.update(outputs)

This will ensure that a new callbacks object is created for each chain, avoiding the cascading run of callbacks at each step.
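As a rough illustration of the per-step pattern, here is a toy model of the loop with the child manager created inside it, using the `step_{i+1}` tag that the sync code in the issue passes to `get_child`. It does not use LangChain: `TagRecordingHandler`, `ToyManager`, `ToyStep`, and `run_with_child_per_step` are hypothetical, and the sketch assumes each step attaches its own handler to the manager it receives.

```python
import asyncio
from typing import List, Optional


class TagRecordingHandler:
    """Remembers the tags of every step start it was notified about."""

    def __init__(self) -> None:
        self.seen_tags: List[List[str]] = []

    def on_start(self, tags: List[str]) -> None:
        self.seen_tags.append(list(tags))


class ToyManager:
    """Hypothetical stand-in for a callback manager (not LangChain's class)."""

    def __init__(self, tags: Optional[List[str]] = None) -> None:
        self.handlers: List[TagRecordingHandler] = []
        self.tags: List[str] = tags or []

    def get_child(self, tag: Optional[str] = None) -> "ToyManager":
        # Each call returns a fresh manager with an empty handler list.
        return ToyManager(tags=[tag] if tag else [])


class ToyStep:
    """Hypothetical chain step that owns exactly one handler."""

    def __init__(self, handler: TagRecordingHandler) -> None:
        self.handler = handler

    async def acall(self, value: str, callbacks: ToyManager) -> str:
        # Assumption: the step registers its handler on the manager it is given.
        callbacks.handlers.append(self.handler)
        for handler in callbacks.handlers:
            handler.on_start(callbacks.tags)
        return value + "foo"


async def run_with_child_per_step(steps: List[ToyStep], value: str) -> str:
    parent = ToyManager()
    for i, step in enumerate(steps):
        child = parent.get_child(f"step_{i+1}")  # new child on every iteration
        value = await step.acall(value, callbacks=child)
    return value


handlers = [TagRecordingHandler() for _ in range(3)]
output = asyncio.run(run_with_child_per_step([ToyStep(h) for h in handlers], "123"))
seen = [h.seen_tags for h in handlers]
```

In this model each handler is notified exactly once, with only its own step's tag, because no handler list survives from one iteration to the next.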

I will create a pull request with this fix and link it to this issue for tracking. Once the pull request is merged, this fix will be included in the next release of LangChain. Thank you for your contribution to improving LangChain.


@vamseeyarla
Contributor Author

vamseeyarla commented Aug 17, 2023

Here is a new unit test for this issue.

diff --git a/libs/langchain/tests/unit_tests/chains/test_sequential.py b/libs/langchain/tests/unit_tests/chains/test_sequential.py
index 12f72c6fe7d6323d3911a25b17074e7768952c11..9af44f24ab2335e3ab250755ac9f30092eff8ed1 100644
--- a/libs/langchain/tests/unit_tests/chains/test_sequential.py
+++ b/libs/langchain/tests/unit_tests/chains/test_sequential.py
@@ -3,11 +3,12 @@ from typing import Dict, List, Optional
 
 import pytest
 
-from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
 from langchain.memory import ConversationBufferMemory
 from langchain.memory.simple import SimpleMemory
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
 class FakeChain(Chain):
@@ -37,6 +38,13 @@ class FakeChain(Chain):
             outputs[var] = f"{' '.join(variables)}foo"
         return outputs
 
+    async def _acall(
+        self,
+        inputs: Dict[str, str],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        return self._call(inputs, run_manager)
+
 
 def test_sequential_usage_single_inputs() -> None:
     """Test sequential on single input chains."""
@@ -165,6 +173,42 @@ def test_simple_sequential_functionality() -> None:
     assert output == expected_output
 
 
+def test_simple_sequential_functionality_with_callbacks() -> None:
+    """Test simple sequential functionality."""
+    handler_1 = FakeCallbackHandler()
+    handler_2 = FakeCallbackHandler()
+    handler_3 = FakeCallbackHandler()
+    chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1])
+    chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2])
+    chain_3 = FakeChain(input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3])
+    chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3])
+    output = chain({"input": "123"})
+    expected_output = {"output": "123foofoofoo", "input": "123"}
+    assert output == expected_output
+    for handler in [handler_1, handler_2, handler_3]:
+        assert handler.starts == 1
+        assert handler.ends == 1
+        assert handler.errors == 0
+
+
+@pytest.mark.asyncio
+async def test_simple_sequential_functionality_with_callbacks_async() -> None:
+    """Test simple sequential functionality."""
+    handler_1 = FakeCallbackHandler()
+    handler_2 = FakeCallbackHandler()
+    handler_3 = FakeCallbackHandler()
+    chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1])
+    chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2])
+    chain_3 = FakeChain(input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3])
+    chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3])
+    output = await chain.ainvoke({"input": "123"})
+    expected_output = {"output": "123foofoofoo", "input": "123"}
+    assert output == expected_output
+    for handler in [handler_1, handler_2, handler_3]:
+        assert handler.starts == 1
+        assert handler.ends == 1
+        assert handler.errors == 0
+
 def test_multi_input_errors() -> None:
     """Test simple sequential errors if multiple input variables are expected."""
     chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])

baskaryan pushed a commit that referenced this issue Aug 18, 2023
… async mode (#9452)

Issue: #9401

In the Async mode, SequentialChain implementation seems to run the same
callbacks over and over since it is re-using the same callbacks object.

Langchain version: 0.0.264, master

The implementation of the async route differs from the sync route. The
sync approach follows the right pattern: it generates a new callbacks
object for each step instead of reusing the old one, which avoids the
cascading run of callbacks at each step.

Async mode:
```
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        ...
        for i, chain in enumerate(self.chains):
            _input = await chain.arun(_input, callbacks=callbacks)
            ...
```

Regular mode:
```
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        for i, chain in enumerate(self.chains):
            _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
            ...
```

Notice how we are reusing the callbacks object in the Async code which
will have a cascading effect as we run through the chain. It runs the
same callbacks over and over resulting in issues.

Solution:
Define the async function in the same pattern as the regular one and
add tests.
---------

Co-authored-by: vamsee_yarlagadda <vamsee.y@airbnb.com>