diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index b24249bbe96..2abc8a9ac94 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -122,8 +122,8 @@ def __init__( - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. - - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat. - This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None. + - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb. + Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None. - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string. The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function. Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models. @@ -178,9 +178,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self.customized_prompt = self._retrieve_config.get("customized_prompt", None) self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() self.update_context = self._retrieve_config.get("update_context", True) - self._get_or_create = ( - self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False - ) + self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token) self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None) self._context_max_tokens = self._max_tokens * 0.8 @@ -360,7 +358,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = n_results (int): the number of results to be retrieved. search_string (str): only docs containing this string will be retrieved. """ - if not self._collection or self._get_or_create: + if not self._collection or not self._get_or_create: print("Trying to create collection.") self._client = create_vector_db_from_dir( dir_path=self._docs_path, @@ -375,7 +373,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = custom_text_split_function=self.custom_text_split_function, ) self._collection = True - self._get_or_create = False + self._get_or_create = True results = query_vector_db( query_texts=[problem], diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index b98ba862d1a..607608f5c03 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -242,7 +242,7 @@ def create_vector_db_from_dir( db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db". collection_name (Optional, str): the name of the collection. Default is "all-my-documents". get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection - will be recreated if it already exists. + will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False. chunk_mode (Optional, str): the chunk mode. Default is "multi_lines". must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True. embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if diff --git a/notebook/agentchat_RetrieveChat.ipynb b/notebook/agentchat_RetrieveChat.ipynb index faee4cf2bf5..4aabc52b01e 100644 --- a/notebook/agentchat_RetrieveChat.ipynb +++ b/notebook/agentchat_RetrieveChat.ipynb @@ -212,7 +212,7 @@ " \"model\": config_list[0][\"model\"],\n", " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", " \"embedding_model\": \"all-mpnet-base-v2\",\n", - " \"get_or_create\": False, # set to True if you want to recreate the collection\n", + " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection, but you'll need to remove the collection manually\n", " },\n", ")" ] @@ -4172,7 +4172,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.13" } }, "nbformat": 4,