Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add group chat and retrieve agent example #227

Merged
merged 17 commits into from
Oct 17, 2023
Merged
19 changes: 16 additions & 3 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def __init__(
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
- n_results (Optional, int): the number of results to be retrieved. Useful in group chat. Will be overridden by the same
parameter passed to `generate_init_message`. Default is 20.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).

Example of overriding retrieve_docs:
Expand Down Expand Up @@ -175,6 +177,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
self.n_results = self._retrieve_config.get("n_results", 20)
self._initchat_from_this = False # assume this is not the agent to init chat
self._context_max_tokens = self._max_tokens * 0.8
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
Expand Down Expand Up @@ -297,6 +301,14 @@ def _generate_retrieve_user_reply(
messages = self._oai_messages[sender]
message = messages[-1]
update_context_case1, update_context_case2 = self._check_update_context(message)
if not self._initchat_from_this and not hasattr(self, "problem"):
# the first time the agent is called in a group chat
message = self.generate_init_message(message.get("content", ""), n_results=self.n_results)
self._initchat_from_this = False # reset to False as the value is changed to True in generate_init_message
return True, message
elif not self._initchat_from_this and hasattr(self, "problem"):
# the agent is called in a group chat and the problem is already set
update_context_case1 = True
thinkall marked this conversation as resolved.
Show resolved Hide resolved
if (update_context_case1 or update_context_case2) and self.update_context:
print(colored("Updating context and resetting conversation.", "green"), flush=True)
# extract the first sentence in the response as the intermediate answer
Expand Down Expand Up @@ -380,21 +392,22 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._results = results
print("doc_ids: ", results["ids"])

def generate_init_message(self, problem: str, n_results: int = 20, search_string: str = ""):
def generate_init_message(self, problem: str, n_results: int = None, search_string: str = ""):
"""Generate an initial message with the given problem and prompt.

Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved.
n_results (int): the number of results to be retrieved. Default is None, will use the value set in retrieve_config.
search_string (str): only docs containing this string will be retrieved.

Returns:
str: the generated prompt ready to be sent to the assistant agent.
"""
self._reset()
self._initchat_from_this = True
self.retrieve_docs(problem, n_results, search_string)
self.problem = problem
self.n_results = n_results
self.n_results = n_results if n_results is not None else self.n_results
doc_contents = self._get_context(self._results)
message = self._generate_message(doc_contents, self._task)
return message
Expand Down
Loading
Loading