Issue: Stream a response from LangChain's OpenAI with Python Flask API #4945

Closed

zigax1 opened this issue May 18, 2023 · 31 comments

@zigax1

zigax1 commented May 18, 2023

Issue you'd like to raise.

I am using a Python Flask app for chat over data. In the console I am getting a streamed response directly from OpenAI, since I can enable streaming with the flag streaming=True.

The problem is that I can't “forward” the stream, or “show” the stream, in my API call.

The code for processing the OpenAI call and the chain is:

def askQuestion(self, collection_id, question):
        collection_name = "collection-" + str(collection_id)
        self.llm = ChatOpenAI(model_name=self.model_name, temperature=self.temperature, openai_api_key=os.environ.get('OPENAI_API_KEY'), streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,  output_key='answer')
        
        chroma_Vectorstore = Chroma(collection_name=collection_name, embedding_function=self.embeddingsOpenAi, client=self.chroma_client)


        self.chain = ConversationalRetrievalChain.from_llm(self.llm, chroma_Vectorstore.as_retriever(similarity_search_with_score=True),
                                                            return_source_documents=True,verbose=VERBOSE, 
                                                            memory=self.memory)
        

        result = self.chain({"question": question})
        
        res_dict = {
            "answer": result["answer"],
        }

        res_dict["source_documents"] = []

        for source in result["source_documents"]:
            res_dict["source_documents"].append({
                "page_content": source.page_content,
                "metadata":  source.metadata
            })

        return res_dict

and the API route code:

def stream(collection_id, question):
    completion = document_thread.askQuestion(collection_id, question)
    for line in completion:
        yield 'data: %s\n\n' % line

@app.route("/collection/<int:collection_id>/ask_question", methods=["POST"])
@stream_with_context
def ask_question(collection_id):
    question = request.form["question"]
    # response_generator = document_thread.askQuestion(collection_id, question)
    # return jsonify(response_generator)

    def stream(question):
        completion = document_thread.askQuestion(collection_id, question)
        for line in completion['answer']:
            yield line

    return Response(stream(question), mimetype='text/event-stream')

I am testing my endpoint with curl, passing the -N flag, so I should get a streamed response if that is possible.

When I make the API call, the endpoint first waits to process the data (I can see the streamed answer in my VS Code terminal), and only when it has finished do I get everything displayed in one go.

Thanks

Suggestion:

No response

@AvikantSrivastava

You could use the stream_with_context function and pass in the stream generator:
https://flask.palletsprojects.com/en/2.1.x/patterns/streaming/

from flask import request, stream_with_context

@app.route("/collection/<int:collection_id>/ask_question", methods=["POST"])
def ask_question(collection_id):
    question = request.form["question"]
    # response_generator = document_thread.askQuestion(collection_id, question)
    # return jsonify(response_generator)

    def stream(question):
        completion = document_thread.askQuestion(collection_id, question)
        for line in completion['answer']:
            yield line

    return app.response_class(stream_with_context(stream(question)))

@zigax1
Author

zigax1 commented May 20, 2023

You could use the stream_with_context function and pass in the stream generator: https://flask.palletsprojects.com/en/2.1.x/patterns/streaming/

Sadly it doesn't work and I did exactly as you told me.

@sunwooz

sunwooz commented May 25, 2023

I'm also wondering how this is done. I tried stream_template and stream_with_context, but my server only sends the response once it is done loading, not while it is streaming. I also tried different callback handlers, to no avail.

@AvikantSrivastava

@agola11 can you answer this?

I tried doing the same in FastAPI and it did not work. I raised an issue: #5296

@zigax1
Author

zigax1 commented May 27, 2023

I am still playing around and trying to solve it, but without any success.

@agola11
@hwchase17
@AvikantSrivastava

For now, my code looks like this:


class MyCustomHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        yield token

class DocumentThread:

    def askQuestion(self, collection_id, question):
        collection_name = "collection-" + str(collection_id)
        self.llm = ChatOpenAI(model_name=self.model_name, temperature=self.temperature, openai_api_key=os.environ.get('OPENAI_API_KEY'), streaming=True, callback_manager=CallbackManager([MyCustomHandler()]))
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,  output_key='answer')
        
        chroma_Vectorstore = Chroma(collection_name=collection_name, embedding_function=self.embeddingsOpenAi, client=self.chroma_client)

   
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, chroma_Vectorstore.as_retriever(similarity_search_with_score=True),
                                                            return_source_documents=True,verbose=VERBOSE, 
                                                            memory=self.memory)
        
        result = self.chain({"question": question})
        
        res_dict = {
            "answer": result["answer"],
        }

        res_dict["source_documents"] = []

        for source in result["source_documents"]:
            res_dict["source_documents"].append({
                "page_content": source.page_content,
                "metadata":  source.metadata
            })

        return res_dict

and the endpoint definition:

@app.route("/collection/<int:collection_id>/ask_question", methods=["POST"])
def ask_question(collection_id):
    question = request.form["question"]

    def generate_tokens(question):  
        result = document_thread.askQuestion(collection_id, question)
        for token in result['answer']:
            yield token

    return Response(stream_with_context(generate_tokens(question)), mimetype='text/event-stream')


@longmans

What you need is to override StreamingStdOutCallbackHandler's on_llm_new_token method. I realized that the method only prints the token to the stream but does nothing with the output. So I put the token into a Queue in one thread, then read it from the other thread. It works for me.

import os
import queue
import sys
from typing import Any, Dict, List

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import LLMResult

q = queue.Queue()
os.environ["OPENAI_API_KEY"] = "sk-your-key"
stop_item = "###finish###"

class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with q.mutex:
            q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
        sys.stdout.flush()
        q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        q.put(stop_item)


llm = ChatOpenAI(temperature=0.5, streaming=True, callbacks=[
                 StreamingStdOutCallbackHandlerYield()])
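
The reading side could look roughly like this (a sketch, assuming the q, stop_item and llm globals above, plus a chain built on that llm, e.g. the ConversationalRetrievalChain from the original post):

import threading

from flask import Flask, Response, request, stream_with_context

app = Flask(__name__)

@app.route("/ask", methods=["POST"])
def ask():
    question = request.form["question"]
    # Run the blocking chain call in a background thread; tokens arrive via `q`.
    threading.Thread(target=lambda: chain({"question": question})).start()

    def generate():
        while True:
            token = q.get()          # blocks until the callback puts a token
            if token == stop_item:   # sentinel pushed by on_llm_end
                break
            yield token

    return Response(stream_with_context(generate()), mimetype="text/event-stream")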

@zigax1
Author

zigax1 commented May 29, 2023

Switched from Flask to FastAPI.. Moved to: #5409

@mziru
Contributor

mziru commented May 29, 2023

Working on a similar implementation but can't get it to work. @longmans, would you mind sharing how you're reading the queue from the other thread?

@mziru
Contributor

mziru commented May 29, 2023

wait, nevermind, got it to work! thanks for the first answer.

@qixiang-mft

qixiang-mft commented Jun 1, 2023

chain.apply doesn't return a generator for a synchronous call, which makes it hard to stream the output. Why not use the asyncio API aapply, which makes token-by-token output possible?

Be careful with AsyncIteratorCallbackHandler: it stops the iterator when the stream completes, so you need to account for the remaining tokens and return them as the last data event.
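
A rough sketch of that idea (the function here is illustrative, not from the comment above; AsyncIteratorCallbackHandler.aiter() ends once the handler is marked done):

import asyncio

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI

async def stream_answer(question: str):
    handler = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(streaming=True, callbacks=[handler])
    # Start the model call without awaiting it, then drain the handler's async iterator.
    task = asyncio.create_task(llm.apredict(question))
    async for token in handler.aiter():
        yield token
    await task  # surface any exception raised by the model call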

@varunsinghal

Using threading and a callback, we can get a streaming response from a Flask API.

In the Flask API, you can create a queue and register tokens on it through LangChain's callback.

class StreamingHandler(BaseCallbackHandler):
    ...

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.queue.put(token)

You can then read tokens from the same queue in your Flask route.

from queue import Queue
import threading

from flask import Response, stream_with_context

@app.route("...")  # your route here
def stream_output():
    q = Queue()

    def generate(rq: Queue):
        ...
        # add your logic to prevent the while loop
        # from running indefinitely
        while ...:
            yield rq.get()

    callback_fn = StreamingHandler(q)

    # collection_id and question come from your route, as in the original post
    threading.Thread(target=askQuestion, args=(collection_id, question, callback_fn)).start()
    return Response(stream_with_context(generate(q)))

In your LangChain ChatOpenAI, add the custom StreamingHandler callback defined above.

self.llm = ChatOpenAI(
  model_name=self.model_name, 
  temperature=self.temperature, 
  openai_api_key=os.environ.get('OPENAI_API_KEY'), 
  streaming=True, 
  callbacks=[callback_fn]
)

For reference:

https://python.langchain.com/en/latest/modules/callbacks/getting_started.html#creating-a-custom-handler
https://flask.palletsprojects.com/en/2.3.x/patterns/streaming/#streaming-with-context
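
For anyone filling in the `...` above: the handler presumably just stores the queue it is given. A minimal sketch (this constructor is an assumption, not part of the comment above):

from queue import Queue

from langchain.callbacks.base import BaseCallbackHandler

class StreamingHandler(BaseCallbackHandler):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue  # the route-level queue that generate() reads from

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called for every streamed token when streaming=True
        self.queue.put(token)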

@francisjervis

@varunsinghal @longmans nice work. I am building Flask-Langchain and want to include streaming functionality. Have you tested this approach with multiple concurrent requests?

It would be fantastic if one of you could open a PR against the Flask-Langchain project adding an extension-based callback handler and a route class (or decorator?) to handle streaming responses; this probably isn't functionality that belongs in the main LangChain library, as it is Flask-specific.

@VionaWang

@varunsinghal Thank you for the great answer! Could you elaborate more on the implementation? I couldn't reproduce your method and get it to work. Thanks in advance!

@riccardolinares

Working on the same problem, no success at the moment... @varunsinghal I don't quite follow your solution, to be honest.

@varunsinghal

Hi @VionaWang @riccardolinares, can you please share your code samples, so that I can suggest fixes or debug what might be going wrong?

@manuel-84

manuel-84 commented Jun 23, 2023

With the usage of threading and callback we can have a streaming response from flask API.

I managed to get streaming working, BUT with a ConversationalRetrievalChain it also streams the condensed question before the answer. I tried replacing BaseCallbackHandler with FinalStreamingStdOutCallbackHandler, but the result is the same.

@manuel-84

manuel-84 commented Jun 23, 2023

Solved it in a very hacky way (it can of course be improved): if the prompt comes from the question condenser, its streamed tokens are discarded, so the final streamed output contains only the answer, without the condensed question.


from typing import Any, Dict, List
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler


class QueueCallback(BaseCallbackHandler):
    def __init__(self, q):
        self.q = q
        self.discard = False

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> Any:
        # The condense-question prompt contains "Standalone question"; discard its tokens.
        self.discard = 'Standalone question' in prompts[0]

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if not self.discard:
            self.q.put(token)

    def on_llm_end(self, *args, **kwargs: Any) -> None:
        return self.q.empty()
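
An alternative to inspecting the prompt text: if your LangChain version exposes condense_question_llm on ConversationalRetrievalChain.from_llm, you can give the condensing step its own non-streaming model so only the answer ever reaches the queue. A sketch (the parameter's availability and the surrounding q, vectorstore and memory objects are assumptions):

from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI

streaming_llm = ChatOpenAI(streaming=True, callbacks=[QueueCallback(q)])
condense_llm = ChatOpenAI()  # no streaming callback, so its tokens never reach the queue

chain = ConversationalRetrievalChain.from_llm(
    llm=streaming_llm,                    # answers the question; its tokens are streamed
    condense_question_llm=condense_llm,   # rewrites the follow-up question; not streamed
    retriever=vectorstore.as_retriever(),
    memory=memory,
)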

@JoAmps

JoAmps commented Jun 28, 2023

@stream_with_context

How did you make it work? It's been bugging me. Also, where do you import LLMResult from?

@JoAmps

JoAmps commented Jun 28, 2023

Using threading and a callback, we can get a streaming response from a Flask API.

It would be great if you showed the whole code.

@youssef595

Please, I can't see the code of the working solution. Can you show it?

@usersina

Here's a full minimal working example, drawing from all of the answers above (with typings, modularity using Blueprints, and minimal error handling as a bonus):

To explain how it all works:

  1. The controller endpoint defines an ask_question function. This function is responsible for starting the generation process in a separate thread as soon as we hit the endpoint. Note how it uses a custom callback of type StreamingStdOutCallbackHandlerYield and sets streaming=True. It delegates all of its streaming behavior to the custom class and uses a q variable that I will talk about shortly.
  2. The return type of the controller is a Response that runs the generate function. This function is the one that's actually "listening" for the streamable LLM output and yielding the result back as a stream to the HTTP caller as soon as it gets it.
  3. The way it all works is thanks to the StreamingStdOutCallbackHandlerYield. It basically writes all LLM output as soon as it comes back from OpenAI. Note how it writes it back to a Queue object that's created at controller level.
  4. Finally, see how I stop the generate function as soon as I get a special literal named STOP_ITEM. This is returned from the custom callback when on_llm_end is executed, or when we have an error (on_llm_error), in which case I also return the error just before the STOP_ITEM.

routes/stream.py

import os
import threading
from queue import Queue
from typing import TypedDict

from flask import Blueprint, Response, request
from langchain.llms import OpenAI

from utils.streaming import StreamingStdOutCallbackHandlerYield, generate

page = Blueprint(os.path.splitext(os.path.basename(__file__))[0], __name__)

# Define the expected input type
class Input(TypedDict):
    prompt: str

@page.route("/", methods=["POST"])
def stream_text() -> Response:
    data: Input = request.get_json()

    prompt = data["prompt"]
    q = Queue()

    def ask_question(callback_fn: StreamingStdOutCallbackHandlerYield):
        # Note that a try/catch is not needed here. Callback takes care of all errors in `on_llm_error`
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm(prompt=prompt)

    callback_fn = StreamingStdOutCallbackHandlerYield(q)
    threading.Thread(target=ask_question, args=(callback_fn,)).start()
    return Response(generate(q), mimetype="text/event-stream")

utils/streaming.py

import queue
from typing import Any, Dict, List, Union

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult

STOP_ITEM = "[END]"
"""
This is a special item that is used to signal the end of the stream.
"""


class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    """
    This is a callback handler that yields the tokens as they are generated.
    For a usage example, see the :func:`generate` function below.
    """

    q: queue.Queue
    """
    The queue to write the tokens to as they are generated.
    """

    def __init__(self, q: queue.Queue) -> None:
        """
        Initialize the callback handler.
        q: The queue to write the tokens to as they are generated.
        """
        super().__init__()
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with self.q.mutex:
            self.q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Writes to stdout
        # sys.stdout.write(token)
        # sys.stdout.flush()
        # Pass the token to the generator
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.put(STOP_ITEM)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.q.put("%s: %s" % (type(error).__name__, str(error)))
        self.q.put(STOP_ITEM)


def generate(rq: queue.Queue):
    """
    This is a generator that yields the items in the queue until it reaches the stop item.

    Usage example:
    ```
    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm(prompt="Write a poem about a tree.")

    @app.route("/", methods=["GET"])
    def generate_output():
        q = Queue()
        callback_fn = StreamingStdOutCallbackHandlerYield(q)
        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
        return Response(generate(q), mimetype="text/event-stream")
    ```
    """
    while True:
        result: str = rq.get()
        if result == STOP_ITEM or result is None:
            break
        yield result

Complete folder structure

Here's the working tree, in case you're unsure where the files are located:

.
├── README.md
├── requirements.txt
└── src
    ├── main.py
    ├── routes
    │   └── stream.py
    └── utils
        └── streaming.py

main.py:

from dotenv import load_dotenv
from flask import Flask
from flask_cors import CORS

from routes.stream import page as stream_route

# Load environment variables
load_dotenv(
    dotenv_path=".env",  # Relative to where the script is running from
)

app = Flask(__name__)
# See https://github.com/corydolphin/flask-cors/issues/257
app.url_map.strict_slashes = False

CORS(app)

app.register_blueprint(stream_route, url_prefix="/api/chat")

if __name__ == "__main__":
    app.run()

I will soon follow with a full repository (probably)

@usersina

My previous solution is a performance killer, so here's a better, more concise one:

import asyncio
import json

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.memory import ConversationSummaryBufferMemory
from langchain.chains import ConversationChain
from langchain.llms.openai import OpenAI

# `page` is assumed to be the Blueprint from the earlier routes/stream.py example
@page.route("/general", methods=["POST"])
async def general_chat():
    try:
        memory = ConversationSummaryBufferMemory(
            llm=OpenAI(), chat_memory=[]
        )
        handler = AsyncIteratorCallbackHandler()
        conversation = ConversationChain(
            llm=OpenAI(streaming=True, callbacks=[handler]), memory=memory
        )

        async def ask_question_async():
            asyncio.create_task(conversation.apredict(input="Hello, how are you?"))
            async for chunk in handler.aiter():
                yield f"data: {json.dumps({'content': chunk, 'tokens': 0})}\n\n"

        return ask_question_async(), {"Content-Type": "text/event-stream"}

    except Exception as e:
        return {"error": "{}: {}".format(type(e).__name__, str(e))}, 500

Note that AsyncIteratorCallbackHandler does not work with agents yet. See this issue for more details.

@girithodu

memory = ConversationSummaryBufferMemory(
            llm=OpenAI(), chat_memory=[]
        )
        handler = AsyncIteratorCallbackHandler()
        conversation = ConversationChain(
            llm=OpenAI(streaming=True, callbacks=[handler]), memory=memory
        )

        async def ask_question_async():
            asyncio.create_task(conversation.apredict(input="Hello, how are you?"))

What led you to choose conversation.apredict instead of the standard method of directly passing the user query to the created chain?

@usersina

Because apredict is asynchronous. In fact, you might also be able to call arun directly, IIRC. In the end, all of these methods go through Chain.__call__. I can't say much about performance without any benchmarking, though...
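
For reference, either async entry point works here (a small sketch, assuming the conversation chain defined above):

import asyncio

async def main() -> None:
    # Both end up in the chain's async call path:
    print(await conversation.apredict(input="Hello, how are you?"))
    print(await conversation.arun(input="Hello, how are you?"))

asyncio.run(main())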

@JoAmps

JoAmps commented Jan 31, 2024

@usersina

How about doing this with a retrieval chain? I'm trying to, but I'm getting errors.


@usersina

usersina commented Feb 6, 2024

@JoAmps I'm not too sure without seeing any code, but I really recommend you switch over to LCEL; there's so much you can customize and implement that way, especially as you move closer to production.
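
For context, the LCEL route looks roughly like this in Flask (a minimal sketch, not tied to the retrieval setup above; Runnable.stream() yields chunks synchronously, so no callback or queue plumbing is needed):

from flask import Flask, Response, request, stream_with_context
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

app = Flask(__name__)

prompt = ChatPromptTemplate.from_template("Answer the question: {question}")
chain = prompt | ChatOpenAI() | StrOutputParser()

@app.route("/ask", methods=["POST"])
def ask():
    question = request.form["question"]

    def generate():
        # chain.stream() yields string chunks as they arrive from the model
        for chunk in chain.stream({"question": question}):
            yield chunk

    return Response(stream_with_context(generate()), mimetype="text/event-stream")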

@ElderBlade

@usersina thanks for providing your code. I've tried what you recommended in your comment, and it works, except I do not get the final output from the agent. I get the chain's thought process returned in my Flask app, but it stops short of returning the final answer. What am I missing?

streaming.py

import sys
import queue
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

STOP_ITEM = "[END]"
"""
This is a special item that is used to signal the end of the stream.
"""

class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    """
    This is a callback handler that yields the tokens as they are generated.
    For a usage example, see the :func:`generate` function below.
    """

    q: queue.Queue
    """
    The queue to write the tokens to as they are generated.
    """

    def __init__(self, q: queue.Queue) -> None:
        """
        Initialize the callback handler.
        q: The queue to write the tokens to as they are generated.
        """
        super().__init__()
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with self.q.mutex:
            self.q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Writes to stdout
        sys.stdout.write(token)
        sys.stdout.flush()
        # Pass the token to the generator
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        sys.stdout.write("THE END!!!")
        self.q.put(response.output)
        self.q.put(STOP_ITEM)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        sys.stdout.write(f"LLM Error: {error}\n")
        self.q.put("%s: %s" % (type(error).__name__, str(error)))
        self.q.put(STOP_ITEM)
    
    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Print out that we are entering a chain."""
        self.q.put("Entering the chain...\n\n")

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
        sys.stdout.write(f"Tool: {serialized['name']}\n")
        self.q.put(f"Tool: {serialized['name']}\n")
    
    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        sys.stdout.write(f"{action.log}\n")
        self.q.put(f"{action.log}\n")

def generate(rq: queue.Queue):
    """
    This is a generator that yields the items in the queue until it reaches the stop item.

    Usage example:
    ```
    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm(prompt="Write a poem about a tree.")

    @app.route("/", methods=["GET"])
    def generate_output():
        q = Queue()
        callback_fn = StreamingStdOutCallbackHandlerYield(q)
        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
        return Response(generate(q), mimetype="text/event-stream")
    ```
    """
    while True:
        result: str = rq.get()
        if result == STOP_ITEM or result is None:
            break
        yield result
            

routes.py

@app.route('/chat', methods=['POST'])
@auth.secured()
def chat():
    message = request.json['messages']
    
    chat_message_history = CustomChatMessageHistory(
    session_id=session['conversation_id'], connection_string="sqlite:///chat_history.db"
    )
    
    q = Queue()
    callback_fn = StreamingStdOutCallbackHandlerYield(q)

    def ask_question(callback_fn: StreamingStdOutCallbackHandlerYield):
        
        # Callback manager
        cb_manager = CallbackManager(handlers=[callback_fn])
        
        ## SQLSearchAgent is a custom Tool class created to do Q&A over a MS SQL Database
        sql_search = SQLSearchAgent(llm=llm, k=30, callback_manager=cb_manager, return_direct=True)

        ## ChatGPTTool is a custom Tool class created to talk to ChatGPT knowledge
        chatgpt_search = ChatGPTTool(llm=llm, callback_manager=cb_manager, return_direct=True)
        tools = [sql_search, chatgpt_search]

        agent = ConversationalChatAgent.from_llm_and_tools(llm=llm, tools=tools, system_message=CUSTOM_CHATBOT_PREFIX, human_message=CUSTOM_CHATBOT_SUFFIX)
        memory = ConversationBufferWindowMemory(memory_key="chat_history", return_messages=True, k=10, chat_memory=chat_message_history)
        brain_agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory, handle_parsing_errors=True, streaming=True)
        return brain_agent_executor.run(message['content'])
   
    threading.Thread(target=ask_question, args=(callback_fn,)).start()
    return Response(generate(q), mimetype="text/event-stream")

@usersina

usersina commented Feb 20, 2024

@mmoore7 - there might have been a change to the stop condition, or the tool / train-of-thought end event is getting called. I can't say for sure, since I have long since moved from Flask and classic LangChain to LangChain Expression Language and FastAPI for better streaming.

@eyurtsev
Collaborator

eyurtsev commented Mar 7, 2024

LangServe has a number of examples that get streaming working out of the box with FastAPI.

https://github.com/langchain-ai/langserve/tree/main?tab=readme-ov-file#examples

We strongly recommend using LCEL and, depending on what you're doing, either the astream API or the astream_events API.

I am marking this issue as closed, as there are enough examples and documentation for folks to solve this without much difficulty.

LangServe will provide streaming that will be available to the RemoteRunnable js client in just a few lines of code!
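
As a rough illustration of the difference (a sketch, assuming chain is any LCEL runnable; the exact event schema is version-dependent):

import asyncio

async def main() -> None:
    # astream() yields the chain's output chunks directly
    async for chunk in chain.astream({"question": "What is LangChain?"}):
        print(chunk, end="", flush=True)

    # astream_events() yields structured events, handy for surfacing intermediate steps
    async for event in chain.astream_events({"question": "What is LangChain?"}, version="v1"):
        if event["event"] == "on_chat_model_stream":
            print(event["data"]["chunk"].content, end="", flush=True)

asyncio.run(main())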

@eyurtsev eyurtsev closed this as completed Mar 7, 2024
@kishykumar

Flask needs an equivalent of FastAPI's StreamingResponse. I too switched to FastAPI, since the streaming support in Flask is lacking.
