Commit

General refactor (#53)
* streaming and refactor

* remove old code

* save full response

* fix thought not being passed

* add tags

* commonize the init

* Refactor stream and save functionality to use async iterators

- Refactored the `stream_and_save` function to use a new class, `Streamable`, which wraps an async iterator and saves the accumulated content on completion via a callback (see the sketch after this list).
- Updated the `think` and `respond` methods in the `BloomChain` class to return a `Streamable` instance instead of string content, and updated `stream_and_save` to pass them the correct arguments.
- Modified `stream_and_save` to iterate through the `thought_iterator` and `response_iterator` with `async for`, updating the placeholders with each received chunk.
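
For orientation, here is a minimal, self-contained sketch (not part of the commit) of the consumption pattern described above. `Chunk` and `make_stream` are hypothetical stand-ins for LangChain's message chunks and `chain.astream(...)`; the `Streamable` below mirrors the class added in `agent/chain.py`.

```python
# Hypothetical, self-contained sketch: not part of this commit.
# `Chunk` and `make_stream` stand in for LangChain's message chunks
# and `chain.astream(...)`; `Streamable` mirrors the class added below.
import asyncio
from dataclasses import dataclass

@dataclass
class Chunk:
    content: str

async def make_stream():
    for token in ["Hello", ", ", "world"]:
        yield Chunk(token)

class Streamable:
    """Async-iterator wrapper that saves accumulated content on completion."""

    def __init__(self, iterator, callback):
        self.iterator = iterator
        self.callback = callback
        self.content = ""

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            data = await self.iterator.__anext__()
            self.content += data.content
            return self.content          # each item is the text accumulated so far
        except StopAsyncIteration:
            self.callback(self.content)  # save exactly once, when the stream ends
            raise

    async def __call__(self):
        async for _ in self:             # drain the stream ...
            pass
        return self.content              # ... and return the full text

async def main():
    # Streaming consumption, as stream_and_save does with `async for`:
    s = Streamable(make_stream(), lambda text: print(f"saved: {text!r}"))
    async for partial in s:
        print(partial)                   # "Hello", "Hello, ", "Hello, world"
    # One-shot consumption, as BloomChain.chat does by awaiting the instance:
    full = await Streamable(make_stream(), lambda _: None)()
    print(full)                          # "Hello, world"

asyncio.run(main())
```

The callback fires exactly once in either consumption mode, when the underlying stream is exhausted; that is what lets the refactor move memory-saving out of `think`/`respond` and onto stream completion.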
hyusap authored Aug 11, 2023
1 parent b6a19c4 commit a70f3fe
Showing 5 changed files with 117 additions and 74 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
@@ -27,6 +27,8 @@ RUN addgroup --system app && adduser --system --group app
 USER app
 
 COPY agent/ agent/
+COPY common/ common/
+
 COPY bot/ bot/
 
 COPY www/ www/
94 changes: 65 additions & 29 deletions agent/chain.py
@@ -5,10 +5,15 @@
 from langchain.prompts import (
     SystemMessagePromptTemplate,
 )
-from langchain.prompts import load_prompt
-from langchain.schema import AIMessage, HumanMessage
+from langchain.prompts import load_prompt, ChatPromptTemplate
+from langchain.schema import AIMessage, HumanMessage, BaseMessage
 from dotenv import load_dotenv
 
+from collections.abc import AsyncIterator, Awaitable
+from typing import Any, List
+import asyncio
+
+
 load_dotenv()
 
 SYSTEM_THOUGHT = load_prompt(os.path.join(os.path.dirname(__file__), 'prompts/thought.yaml'))
@@ -27,7 +32,7 @@ def restart(self) -> None:
 
 class BloomChain:
     "Wrapper class for encapsulating the multiple different chains used in reasoning for the tutor's thoughts"
-    def __init__(self, llm: ChatOpenAI, verbose: bool = False) -> None:
+    def __init__(self, llm: ChatOpenAI = ChatOpenAI(model_name = "gpt-4", temperature=1.2), verbose: bool = True) -> None:
         self.llm = llm
         self.verbose = verbose
 
@@ -36,50 +41,81 @@ def __init__(self, llm: ChatOpenAI, verbose: bool = False) -> None:
         self.system_response = SystemMessagePromptTemplate(prompt=SYSTEM_RESPONSE)
 
 
-    async def think(self, thought_memory: ChatMessageHistory, input: str) -> str:
+    def think(self, thought_memory: ChatMessageHistory, input: str):
         """Generate Bloom's thought on the user."""
 
         # load message history
-        messages = [self.system_thought.format(), *thought_memory.messages, HumanMessage(content=input)]
-        thought_message = await self.llm.apredict_messages(messages)
+        thought_prompt = ChatPromptTemplate.from_messages([
+            self.system_thought,
+            *thought_memory.messages,
+            HumanMessage(content=input)
+        ])
+        chain = thought_prompt | self.llm
 
         # update chat memory
         thought_memory.add_message(HumanMessage(content=input))
-        thought_memory.add_message(thought_message) # apredict_messages returns AIMessage so can add directly
 
-        return thought_message.content
-
+        return Streamable(
+            chain.astream({}, {"tags": ["thought"]}),
+            lambda thought: thought_memory.add_message(AIMessage(content=thought))
+        )
 
-    async def respond(self, response_memory: ChatMessageHistory, thought: str, input: str) -> str:
+    def respond(self, response_memory: ChatMessageHistory, thought: str, input: str):
         """Generate Bloom's response to the user."""
 
         # load message history
-        messages = [self.system_response.format(thought=thought), *response_memory.messages, HumanMessage(content=input)]
-        response_message = await self.llm.apredict_messages(messages)
+        response_prompt = ChatPromptTemplate.from_messages([
+            self.system_response,
+            *response_memory.messages,
+            HumanMessage(content=input)
+        ])
+        chain = response_prompt | self.llm
 
         # update chat memory
         response_memory.add_message(HumanMessage(content=input))
-        response_memory.add_message(response_message) # apredict_messages returns AIMessage so can add directly
 
-        return response_message.content
+        return Streamable(
+            chain.astream({ "thought": thought }, {"tags": ["response"]}),
+            lambda response: response_memory.add_message(AIMessage(content=response))
+        )
 
 
 
     async def chat(self, cache: ConversationCache, inp: str ) -> tuple[str, str]:
-        thought = await self.think(cache.thought_memory, inp)
-        response = await self.respond(cache.response_memory, thought, inp)
-        return thought, response
+        thought_iterator = self.think(cache.thought_memory, inp)
+        thought = await thought_iterator()
 
+        response_iterator = self.respond(cache.response_memory, thought, inp)
+        response = await response_iterator()
 
-def load_chains() -> BloomChain:
-    """Logic for loading the chain you want to use should go here."""
-    llm = ChatOpenAI(model_name = "gpt-4", temperature=1.2)
+        return thought, response
 
-    # define chain
-    bloom_chain = BloomChain(
-        llm=llm,
-        verbose=True
-    )
 
-    return bloom_chain
+class Streamable:
+    "A async iterator wrapper for langchain streams that saves on completion via callback"
+
+    def __init__(self, iterator: AsyncIterator[BaseMessage], callback):
+        self.iterator = iterator
+        self.callback = callback
+        # self.content: List[Awaitable[BaseMessage]] = []
+        self.content = ""
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            data = await self.iterator.__anext__()
+            self.content += data.content
+            return self.content
+        except StopAsyncIteration as e:
+            self.callback(self.content)
+            raise StopAsyncIteration
+        except Exception as e:
+            raise e
+
+    async def __call__(self):
+        async for _ in self:
+            pass
+        return self.content

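For completeness, a hedged end-to-end sketch (also not part of the commit) of driving the refactored chain. It assumes `OPENAI_API_KEY` is set and that `ConversationCache` can be constructed without arguments:

```python
# Hypothetical driver: not part of this commit. Assumes OPENAI_API_KEY
# is set and that ConversationCache takes no constructor arguments.
import asyncio
from agent.chain import BloomChain, ConversationCache

async def main() -> None:
    chain = BloomChain()  # defaults now live on __init__: gpt-4, temperature 1.2
    cache = ConversationCache()

    # Stream the thought; each iteration yields the accumulated text so far.
    thought_iterator = chain.think(cache.thought_memory, "What is entropy?")
    async for partial in thought_iterator:
        print(partial)

    # Or skip streaming: awaiting the Streamable drains it and returns the text.
    response_iterator = chain.respond(
        cache.response_memory, thought_iterator.content, "What is entropy?"
    )
    print(await response_iterator())

asyncio.run(main())
```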
22 changes: 6 additions & 16 deletions bot/app.py
@@ -1,22 +1,12 @@
 import os
 import discord
-from dotenv import load_dotenv
-from agent.chain import load_chains
+from agent.chain import BloomChain
 from agent.cache import LRUCache
 
-load_dotenv()
-token = os.environ['BOT_TOKEN']
-
-def init():
-    global BLOOM_CHAIN, \
-        CACHE, \
-        THOUGHT_CHANNEL
+from common import init
+from dotenv import load_dotenv
 
-    CACHE = LRUCache(50)
-    THOUGHT_CHANNEL = os.environ["THOUGHT_CHANNEL_ID"]
-    BLOOM_CHAIN = load_chains()
-
-init()
+load_dotenv()
+CACHE, BLOOM_CHAIN, (THOUGHT_CHANNEL, TOKEN) = init()
 
 intents = discord.Intents.default()
 intents.messages = True
@@ -29,4 +19,4 @@ def init():
 bot.load_extension("bot.core")
 
 
-bot.run(token)
+bot.run(TOKEN)
12 changes: 12 additions & 0 deletions common/__init__.py
@@ -0,0 +1,12 @@
+import os
+from agent.cache import LRUCache
+from agent.chain import BloomChain
+
+
+def init():
+    CACHE = LRUCache(50)
+    BLOOM_CHAIN = BloomChain()
+    THOUGHT_CHANNEL = os.environ["THOUGHT_CHANNEL_ID"]
+    TOKEN = os.environ['BOT_TOKEN']
+
+    return CACHE, BLOOM_CHAIN, (THOUGHT_CHANNEL, TOKEN)
61 changes: 32 additions & 29 deletions www/main.py
@@ -1,25 +1,16 @@
 import os
-from dotenv import load_dotenv
 import streamlit as st
 import time
 from agent.cache import LRUCache
-from agent.chain import ConversationCache, load_chains
+from agent.chain import ConversationCache
 import asyncio
 
+from dotenv import load_dotenv
+from common import init
 
-def init():
-    global BLOOM_CHAIN, \
-        CACHE, \
-        THOUGHT_CHANNEL
-
-    CACHE = LRUCache(50)
-    THOUGHT_CHANNEL = os.environ["THOUGHT_CHANNEL_ID"]
-    BLOOM_CHAIN = load_chains()
-
 load_dotenv()
-token = os.environ['BOT_TOKEN']
+CACHE, BLOOM_CHAIN, _ = init()
 
-init()
-
 st.set_page_config(
     page_title="Bloom - Learning. Reimagined.",
@@ -71,12 +62,31 @@ def init():
         st.markdown(message['content'])
 
 
-thought, response = '', ''
-async def chat_and_save(local_chain: ConversationCache, input: str) -> None:
-    global thought, response
-    bloom_chain = BLOOM_CHAIN # if local_chain.conversation_type == "discuss" else WORKSHOP_RESPONSE_CHAIN
-    thought, response = await bloom_chain.chat(local_chain, input)
-    return None
+# thought, response = '', ''
+# async def chat_and_save(local_chain: ConversationCache, input: str) -> None:
+#     global thought, response
+#     bloom_chain = BLOOM_CHAIN # if local_chain.conversation_type == "discuss" else WORKSHOP_RESPONSE_CHAIN
+#     thought, response = await bloom_chain.chat(local_chain, input)
+#     return None
+
+async def stream_and_save(prompt: str) -> None:
+    thought_iterator = BLOOM_CHAIN.think(st.session_state.local_chain.thought_memory, prompt)
+
+    thought_placeholder = st.sidebar.empty()
+    async for thought in thought_iterator:
+        thought_placeholder.markdown(thought)
+
+    response_iterator = BLOOM_CHAIN.respond(st.session_state.local_chain.response_memory, thought_iterator.content, prompt)
+    with st.chat_message('assistant', avatar="https://bloombot.ai/wp-content/uploads/2023/02/bloom-fav-icon@10x-200x200.png"):
+        response_placeholder = st.empty()
+        async for response in response_iterator:
+            response_placeholder.markdown(response)
+
+    st.session_state.messages.append({
+        "role": "assistant",
+        "content": response_iterator.content
+    })
 
 
 
 if prompt := st.chat_input("hello!"):
@@ -86,16 +96,9 @@ async def chat_and_save(local_chain: ConversationCache, input: str) -> None:
         'role': 'user',
         'content': prompt
     })
-    with st.chat_message('assistant', avatar="https://bloombot.ai/wp-content/uploads/2023/02/bloom-fav-icon@10x-200x200.png"):
-        with st.spinner("Thinking..."):
-            asyncio.run(chat_and_save(st.session_state.local_chain, prompt))
-            st.markdown(response)
+    asyncio.run(stream_and_save(prompt))
 
 
-
-st.sidebar.write(thought)
-st.sidebar.divider()
-
-st.session_state.messages.append({
-    "role": "assistant",
-    "content": response
-})
