
Improve multi-turn capability for agent #1248

Open · wants to merge 11 commits into base: main
137 changes: 89 additions & 48 deletions comps/agent/src/README.md

Large diffs are not rendered by default.

67 changes: 41 additions & 26 deletions comps/agent/src/agent.py
@@ -18,7 +18,7 @@
 from comps.agent.src.integrations.agent import instantiate_agent
 from comps.agent.src.integrations.global_var import assistants_global_kv, threads_global_kv
 from comps.agent.src.integrations.thread import instantiate_thread_memory, thread_completion_callback
-from comps.agent.src.integrations.utils import assemble_store_messages, get_args
+from comps.agent.src.integrations.utils import assemble_store_messages, get_args, get_latest_human_message_from_store
 from comps.cores.proto.api_protocol import (
     AssistantsObject,
     ChatCompletionRequest,
@@ -40,7 +40,7 @@

 logger.info("========initiating agent============")
 logger.info(f"args: {args}")
-agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
+agent_inst = instantiate_agent(args)


 class AgentCompletionRequest(ChatCompletionRequest):
@@ -76,7 +76,7 @@ async def llm_generate(input: AgentCompletionRequest):
     if isinstance(input.messages, str):
         messages = input.messages
     else:
-        # TODO: need handle multi-turn messages
+        # last user message
         messages = input.messages[-1]["content"]

     # 2. prepare the input for the agent
@@ -90,7 +90,6 @@ async def llm_generate(input: AgentCompletionRequest):
     else:
         logger.info("-----------NOT STREAMING-------------")
         response = await agent_inst.non_streaming_run(messages, config)
-        logger.info("-----------Response-------------")
         return GeneratedDoc(text=response, prompt=messages)

@@ -100,14 +99,14 @@ class RedisConfig(BaseModel):

 class AgentConfig(BaseModel):
     stream: Optional[bool] = False
-    agent_name: Optional[str] = "OPEA_Default_Agent"
+    agent_name: Optional[str] = "OPEA_Agent"
     strategy: Optional[str] = "react_llama"
-    role_description: Optional[str] = "LLM enhanced agent"
+    role_description: Optional[str] = "AI assistant"
     tools: Optional[str] = None
     recursion_limit: Optional[int] = 5

-    model: Optional[str] = "meta-llama/Meta-Llama-3-8B-Instruct"
-    llm_engine: Optional[str] = None
+    model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct"
+    llm_engine: Optional[str] = "vllm"
     llm_endpoint_url: Optional[str] = None
     max_new_tokens: Optional[int] = 1024
     top_k: Optional[int] = 10
@@ -117,10 +116,14 @@ class AgentConfig(BaseModel):
     return_full_text: Optional[bool] = False
     custom_prompt: Optional[str] = None

-    # short/long term memory
-    with_memory: Optional[bool] = False
-    # persistence
-    with_store: Optional[bool] = False
+    # short/long term memory
+    with_memory: Optional[bool] = True
+    # agent memory config
+    # chat_completion api: only supports volatile memory
+    # assistants api: supports volatile and persistent memory
+    # volatile: in-memory checkpointer - MemorySaver()
+    # persistent: redis store
+    memory_type: Optional[str] = "volatile"  # choices: volatile, persistent
     store_config: Optional[RedisConfig] = None

     timeout: Optional[int] = 60
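Annotating the hunk above: the two memory modes reduce to the following `agent_config` shapes. This is an illustrative sample built from the fields of this model; the Redis URI is a placeholder, not something from this PR.

```python
# Illustrative payloads for AgentConfig's two memory modes; values are placeholders.
volatile_cfg = {
    "with_memory": True,
    "memory_type": "volatile",  # in-memory MemorySaver checkpointer
}
persistent_cfg = {
    "with_memory": True,
    "memory_type": "persistent",  # assistants API only; messages land in Redis
    "store_config": {"redis_uri": "redis://localhost:6379"},  # placeholder URI
}
```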
@@ -147,18 +150,17 @@ class CreateAssistant(CreateAssistantsRequest):
 )
 def create_assistants(input: CreateAssistant):
     # 1. initialize the agent
-    agent_inst = instantiate_agent(
-        input.agent_config, input.agent_config.strategy, with_memory=input.agent_config.with_memory
-    )
+    print("@@@ Initializing agent with config: ", input.agent_config)
+    agent_inst = instantiate_agent(input.agent_config)
     assistant_id = agent_inst.id
     created_at = int(datetime.now().timestamp())
     with assistants_global_kv as g_assistants:
         g_assistants[assistant_id] = (agent_inst, created_at)
     logger.info(f"Record assistant inst {assistant_id} in global KV")

-    if input.agent_config.with_store:
+    if input.agent_config.memory_type == "persistent":
         logger.info("Save Agent Config to database")
-        agent_inst.with_store = input.agent_config.with_store
+        # agent_inst.memory_type = input.agent_config.memory_type
+        print(input)
         global db_client
         if db_client is None:
@@ -172,6 +174,7 @@ def create_assistants(input: CreateAssistant):
     return AssistantsObject(
         id=assistant_id,
         created_at=created_at,
+        model=input.agent_config.model,
     )

@@ -211,7 +214,7 @@ def create_messages(thread_id, input: CreateMessagesRequest):
     if isinstance(input.content, str):
         query = input.content
     else:
-        query = input.content[-1]["text"]
+        query = input.content[-1]["text"]  # content is a list of MessageContent
     msg_id, created_at = thread_inst.add_query(query)

     structured_content = MessageContent(text=query)
@@ -224,15 +227,18 @@ def create_messages(thread_id, input: CreateMessagesRequest):
         assistant_id=input.assistant_id,
     )

-    # save messages using assistant_id as key
+    # save messages using assistant_id_thread_id as key
     if input.assistant_id is not None:
         with assistants_global_kv as g_assistants:
             agent_inst, _ = g_assistants[input.assistant_id]
-        if agent_inst.with_store:
-            logger.info(f"Save Agent Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
+        if agent_inst.memory_type == "persistent":
+            logger.info(f"Save Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
             # if with store, db_client initialized already
             global db_client
-            db_client.put(msg_id, message.model_dump_json(), input.assistant_id)
+            namespace = f"{input.assistant_id}_{thread_id}"
+            # put(key: str, val: dict, collection: str = DEFAULT_COLLECTION)
+            db_client.put(msg_id, message.model_dump_json(), namespace)
+            logger.info(f"@@@ Save message to db: {msg_id}, {message.model_dump_json()}, {namespace}")

     return message

@@ -254,15 +260,24 @@ def create_run(thread_id, input: CreateRunResponse):
     with assistants_global_kv as g_assistants:
         agent_inst, _ = g_assistants[assistant_id]

-    config = {"recursion_limit": args.recursion_limit}
+    config = {
+        "recursion_limit": args.recursion_limit,
+        "configurable": {"session_id": thread_id, "thread_id": thread_id, "user_id": assistant_id},
+    }

-    if agent_inst.with_store:
-        # assemble multi-turn messages
+    if agent_inst.memory_type == "persistent":
         global db_client
-        input_query = assemble_store_messages(db_client.get_all(assistant_id))
+        namespace = f"{assistant_id}_{thread_id}"
+        # get the latest human message from store in the namespace
+        input_query = get_latest_human_message_from_store(db_client, namespace)
+        print("@@@@ Input_query from store: ", input_query)
     else:
         input_query = thread_inst.get_query()
+        print("@@@@ Input_query from thread_inst: ", input_query)

+    print("@@@ Agent instance:")
+    print(agent_inst.id)
+    print(agent_inst.args)
     try:
         return StreamingResponse(
             thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id),
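Reading the three handlers together, the intended multi-turn flow via the assistants API looks roughly like this from the client side. This is a hedged sketch: the port, the exact endpoint paths, and the response field names are assumptions inferred from the handlers, not taken from this diff.

```python
# Hypothetical client walk-through of the assistants-API flow above.
import requests

base = "http://localhost:9090/v1"  # assumed host/port

# 1. create_assistants: the agent is instantiated once and kept in the global KV
assistant = requests.post(
    f"{base}/assistants",
    json={
        "agent_config": {
            "strategy": "react_llama",
            "with_memory": True,
            "memory_type": "persistent",  # config is also saved to the DB
            "store_config": {"redis_uri": "redis://localhost:6379"},
        }
    },
).json()

# 2. create a thread, then create_messages: each turn is stored under
#    the "{assistant_id}_{thread_id}" namespace
thread = requests.post(f"{base}/threads", json={}).json()
requests.post(
    f"{base}/threads/{thread['id']}/messages",
    json={"role": "user", "content": "Hi, I'm Bob.", "assistant_id": assistant["id"]},
)

# 3. create_run: the latest human message is pulled back out of the store and
#    streamed through the agent with thread_id/user_id in the run config
with requests.post(
    f"{base}/threads/{thread['id']}/runs",
    json={"assistant_id": assistant["id"]},
    stream=True,
) as run:
    for line in run.iter_lines():
        print(line.decode())
```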
8 changes: 6 additions & 2 deletions comps/agent/src/integrations/agent.py
@@ -1,9 +1,13 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
+from .storage.persistence_redis import RedisPersistence
 from .utils import load_python_prompt


-def instantiate_agent(args, strategy="react_langchain", with_memory=False):
+def instantiate_agent(args):
+    strategy = args.strategy
+    with_memory = args.with_memory
+
     if args.custom_prompt is not None:
         print(f">>>>>> custom_prompt enabled, {args.custom_prompt}")
         custom_prompt = load_python_prompt(args.custom_prompt)
@@ -22,7 +26,7 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
         print("Initializing ReAct Agent with LLAMA")
         from .strategy.react import ReActAgentLlama

-        return ReActAgentLlama(args, with_memory, custom_prompt=custom_prompt)
+        return ReActAgentLlama(args, custom_prompt=custom_prompt)
     elif strategy == "plan_execute":
         from .strategy.planexec import PlanExecuteAgentWithLangGraph

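The call-site simplification above is the crux of the refactor: strategy and memory options now ride on the single config object instead of extra keyword arguments. A minimal sketch of the new convention, using a stand-in config object (only the fields `instantiate_agent` itself reads are shown; the chosen strategy will read further fields such as model and endpoint when it constructs the agent):

```python
# Sketch of the new single-argument convention for instantiate_agent.
from types import SimpleNamespace

from comps.agent.src.integrations.agent import instantiate_agent

# Stand-in for AgentConfig; the fields below are the ones instantiate_agent reads.
config = SimpleNamespace(
    strategy="react_llama",
    with_memory=True,
    memory_type="volatile",
    custom_prompt=None,  # no custom prompt module to load
)
agent = instantiate_agent(config)
```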
26 changes: 20 additions & 6 deletions comps/agent/src/integrations/strategy/base_agent.py
@@ -3,6 +3,9 @@

 from uuid import uuid4

+from langgraph.checkpoint.memory import MemorySaver
+
+from ..storage.persistence_redis import RedisPersistence
 from ..tools import get_tools_descriptions
 from ..utils import adapt_custom_prompt, setup_chat_model

@@ -12,11 +15,25 @@ def __init__(self, args, local_vars=None, **kwargs) -> None:
         self.llm = setup_chat_model(args)
         self.tools_descriptions = get_tools_descriptions(args.tools)
         self.app = None
-        self.memory = None
         self.id = f"assistant_{self.__class__.__name__}_{uuid4()}"
         self.args = args
         adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
-        print(self.tools_descriptions)
+        print("Registered tools: ", self.tools_descriptions)
+
+        if args.with_memory:
+            if args.memory_type == "volatile":
+                self.memory_type = "volatile"
+                self.checkpointer = MemorySaver()
+                self.store = None
+            elif args.memory_type == "persistent":
+                # print("Using Redis as persistent storage: ", args.store_config.redis_uri)
+                self.store = RedisPersistence(args.store_config.redis_uri)
+                self.memory_type = "persistent"
+            else:
+                raise ValueError("Invalid memory type!")
+        else:
+            self.store = None
+            self.checkpointer = None

     @property
     def is_vllm(self):
@@ -60,10 +77,7 @@ async def non_streaming_run(self, query, config):
         try:
             async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                 message = s["messages"][-1]
-                if isinstance(message, tuple):
-                    print(message)
-                else:
-                    message.pretty_print()
+                message.pretty_print()

             last_message = s["messages"][-1]
             print("******Response: ", last_message.content)
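The checkpointer/store wiring added to `BaseAgent.__init__` is what the multi-turn behavior hangs on: LangGraph restores prior turns whenever the compiled graph is invoked with the same `thread_id`. A self-contained toy example of the same pattern, where the echo node stands in for the real LLM-and-tools agent graph:

```python
# Toy LangGraph graph showing how a checkpointer keys memory by thread_id.
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph


def chatbot(state: MessagesState):
    # Stand-in node; the real agents run an LLM + tools loop here.
    return {"messages": [("ai", f"echo: {state['messages'][-1].content}")]}


graph = StateGraph(MessagesState)
graph.add_node("chatbot", chatbot)
graph.add_edge(START, "chatbot")
graph.add_edge("chatbot", END)
app = graph.compile(checkpointer=MemorySaver())  # the "volatile" memory mode

config = {"configurable": {"thread_id": "thread_1"}}
app.invoke({"messages": [("user", "Hi, I'm Bob.")]}, config)
# Same thread_id -> the checkpointer restores the earlier turns, so the graph
# now sees the whole conversation, not just the newest message.
result = app.invoke({"messages": [("user", "What's my name?")]}, config)
print([m.content for m in result["messages"]])
```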