Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Retriever API #619

Merged
merged 72 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
7b967ee
update
AyushExel Nov 10, 2023
144ed87
update
AyushExel Nov 10, 2023
42cb753
Merge branch 'main' into lance
AyushExel Nov 10, 2023
d74d043
Merge branch 'main' into lance
AyushExel Nov 11, 2023
6f705da
Merge branch 'main' into lance
AyushExel Nov 12, 2023
4baf0ae
update
AyushExel Nov 16, 2023
3f547d4
add tests
AyushExel Nov 17, 2023
77aa60f
update tests
AyushExel Nov 17, 2023
af61aeb
Merge branch 'main' into lance
AyushExel Nov 17, 2023
f400fcb
format
AyushExel Nov 17, 2023
65ad434
update
AyushExel Nov 20, 2023
99fd892
Merge branch 'main' into lance
AyushExel Nov 20, 2023
369de53
update
AyushExel Nov 20, 2023
af4b247
update
AyushExel Nov 20, 2023
5b9a43e
update
AyushExel Nov 20, 2023
9978196
update
AyushExel Nov 20, 2023
f2739cf
update
AyushExel Nov 20, 2023
ffcd207
Merge branch 'main' into lance
AyushExel Nov 20, 2023
66ca291
Merge branch 'main' into lance
AyushExel Nov 21, 2023
079886a
update
AyushExel Nov 21, 2023
6903177
update
AyushExel Nov 21, 2023
e733ce3
Merge branch 'main' into lance
AyushExel Nov 22, 2023
0800bf6
Merge branch 'main' into lance
thinkall Nov 25, 2023
8cf4cb2
update
AyushExel Nov 26, 2023
055e7d7
update
AyushExel Nov 26, 2023
efc64f8
Update autogen/agentchat/contrib/retriever/base.py
AyushExel Nov 28, 2023
4f90d31
Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py
AyushExel Nov 28, 2023
10f1b23
Update setup.py
AyushExel Nov 28, 2023
8a68640
Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py
AyushExel Nov 28, 2023
25ff5c0
Merge branch 'main' into lance
AyushExel Nov 28, 2023
1d06955
update
AyushExel Nov 30, 2023
b214aa6
Merge branch 'main' into lance
AyushExel Nov 30, 2023
01d305f
update tests
AyushExel Nov 30, 2023
4d622df
Merge remote-tracking branch 'refs/remotes/origin/lance' into lance
AyushExel Nov 30, 2023
b4cd6c4
move retrieve utils
AyushExel Nov 30, 2023
9beb1be
make qdrant work
AyushExel Nov 30, 2023
f1ccc4b
update test dir
AyushExel Nov 30, 2023
7f8085e
Merge branch 'main' into lance
AyushExel Nov 30, 2023
574bd11
Merge branch 'main' into lance
AyushExel Nov 30, 2023
fcdb151
Merge branch 'main' into lance
AyushExel Dec 2, 2023
5af04d7
Merge branch 'main' into lance
AyushExel Dec 4, 2023
83dca47
Merge branch 'main' into lance
AyushExel Dec 11, 2023
3df5873
Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py
AyushExel Dec 11, 2023
dad693b
Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py
AyushExel Dec 11, 2023
b3322fd
Update autogen/agentchat/contrib/retrieve_user_proxy_agent.py
AyushExel Dec 11, 2023
b24ceee
Merge branch 'main' into lance
thinkall Dec 12, 2023
a1a7857
upadte testing
AyushExel Dec 12, 2023
eab4c4d
Merge remote-tracking branch 'refs/remotes/origin/lance' into lance
AyushExel Dec 12, 2023
b9deeaf
rename
AyushExel Dec 12, 2023
b669aee
improve coverage
AyushExel Dec 12, 2023
154fa3d
update
AyushExel Dec 12, 2023
c1c6532
update notebook
AyushExel Dec 12, 2023
1cb6931
update
AyushExel Dec 12, 2023
1f02f65
update docstring
AyushExel Dec 12, 2023
96a4136
Test hide bot comments
thinkall Dec 12, 2023
274950e
Revert "Test hide bot comments"
thinkall Dec 12, 2023
795dfbd
Add hide bot comments
thinkall Dec 12, 2023
0de53a3
Revert "Add hide bot comments"
thinkall Dec 12, 2023
bdf45cf
Add comment-hider
thinkall Dec 12, 2023
40edc95
Revert "Add comment-hider"
thinkall Dec 12, 2023
9648e4b
Add hide-comment-action
thinkall Dec 12, 2023
27d8d30
Revert "Add hide-comment-action"
thinkall Dec 12, 2023
5f341a6
Update coverage for retrievechat
thinkall Dec 12, 2023
38eac14
Move retrieve tests to the same folder
thinkall Dec 12, 2023
d1e6078
Sync changes in main branch
thinkall Dec 12, 2023
9397974
Fix import error in tests
thinkall Dec 12, 2023
6a5ea67
update
AyushExel Dec 12, 2023
76d1d36
Merge remote-tracking branch 'refs/remotes/origin/lance' into lance
AyushExel Dec 12, 2023
232eaba
add custom_vectordb test
AyushExel Dec 12, 2023
b37e03f
update test
AyushExel Dec 12, 2023
680e37d
update dosctring
AyushExel Dec 12, 2023
ccef9ca
update test
AyushExel Dec 12, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/contrib-openai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ jobs:
pip install docker
pip install qdrant_client[fastembed]
pip install -e .[retrievechat]
pip install chromadb
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py
coverage run -a -m pytest test/agentchat/contrib/retrievers
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,16 @@ jobs:
- name: Install packages and dependencies for RetrieveChat
run: |
pip install -e .[retrievechat]
pip install chromadb
pip uninstall -y openai
- name: Test RetrieveChat
run: |
pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py
pytest test/agentchat/contrib/retrievers
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib
coverage run -a -m pytest test/agentchat/contrib/retrievers
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
Expand Down
1 change: 0 additions & 1 deletion autogen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .agentchat import *
from .code_utils import DEFAULT_MODEL, FAST_MODEL


# Set the root logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Dict, List, Optional

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS
from autogen.agentchat.contrib.retriever.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS
import logging

logger = logging.getLogger(__name__)
Expand Down
119 changes: 80 additions & 39 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import re

try:
import chromadb
except ImportError:
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
from autogen.agentchat.contrib.retriever.retrieve_utils import TEXT_FORMATS
from autogen.token_count_utils import count_token
from autogen.code_utils import extract_code
from autogen.agentchat.contrib.retriever import get_retriever

from autogen import logger

from typing import Callable, Dict, Optional, Union, List, Tuple, Any
Expand Down Expand Up @@ -94,12 +92,14 @@ def __init__(
The dict can contain the following keys: "content", "role", "name", "function_call".
retrieve_config (dict or None): config for the retrieve agent.
To use default config, set to None. Otherwise, set to a dictionary with the following keys:
- retriever_type (Optional, str): the type of the retriever.
- retriever_path (Optional, str): the path to use for retriever-realted operations. Default is `~/autogen`.
- task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
- client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
- client (Optional, Any): the vectordb client/connection. If key not provided, the Retreiver class should handle it.
- docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
the url to a single file or a list of directories, files and urls.
Default is None, which works only if the collection is already created.
- collection_name (Optional, str): the name of the collection.
If key not provided, a default name `autogen-docs` will be used.
- model (Optional, str): the model to use for the retrieve chat.
Expand All @@ -123,8 +123,14 @@ def __init__(
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
- get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb.
Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
- db_mode (Optional, str): the mode to create the vector db. Possible values are "get", "recreate", "create". Default is "recreate" to
keep the workflow less error-prone. If "get", will try to get an existing collection. If "recreate", will recreate a collection
if the collection already exists. If "create", will create a collection if the collection doesn't exist.
AyushExel marked this conversation as resolved.
Show resolved Hide resolved
Raises ValueError if:
* the collection doesn't exist and "get" is used.
* the collection already exists and "create" is used.
- get_or_create (Optional, bool): [Depricated] if True, will create/recreate a collection for the retrieve chat.
This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None.
- custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
Expand All @@ -136,7 +142,7 @@ def __init__(
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).

Example of overriding retrieve_docs:
If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code.
If you want to set up a customized vector db, and it's not compatible with retriever, you can easily plug in it with below code.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
Expand Down Expand Up @@ -166,10 +172,12 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
human_input_mode=human_input_mode,
**kwargs,
)

self.retriever = None
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._retriever_type = self._retrieve_config.get("retriever_type")
self._retriever_path = self._retrieve_config.get("retriever_path", "~/autogen")
AyushExel marked this conversation as resolved.
Show resolved Hide resolved
self._task = self._retrieve_config.get("task", "default")
self._client = self._retrieve_config.get("client", chromadb.Client())
self._client = self._retrieve_config.get("client", None)
self._docs_path = self._retrieve_config.get("docs_path", None)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
Expand All @@ -188,7 +196,6 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
Expand All @@ -202,6 +209,26 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self._search_string = "" # the search string used in the current query
self._db_mode = self._retrieve_config.get("db_mode")
self._get_or_create = self._retrieve_config.get("get_or_create")
if self._db_mode is not None and self._get_or_create is not None:
logger.warning(
colored(
"Warning: db_mode and get_or_create are both set. get_or_create will be ignored. get_or_create is depricated",
"yellow",
)
)
self._get_or_create = None
elif self._db_mode is None and self._get_or_create is None: # if both not set, set db_mode's default value
self._db_mode = "recreate"
elif self._get_or_create:
logger.warning(
colored(
"Warning: get_or_create is depricated and will be removed from future versions. Use `db_mode` instead",
"yellow",
)
)

# update the termination message function
self._is_termination_msg = (
self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
Expand Down Expand Up @@ -362,13 +389,9 @@ def _generate_retrieve_user_reply(

def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
parameters and add your own parameters with default values, and keep the results in below type.

Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
to `chromadb.api.types.QueryResult` as an example.
the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional.
ids: List[string]
documents: List[List[string]]

Expand All @@ -377,33 +400,51 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
n_results (int): the number of results to be retrieved. Default is 20.
search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
"""
if not self._collection or not self._get_or_create:
print("Trying to create collection.")
self._client = create_vector_db_from_dir(
dir_path=self._docs_path,
if not self.retriever:
retriever_class = get_retriever(self._retriever_type)
self.retriever = retriever_class(
path=self._retriever_path,
name=self._collection_name,
embedding_model_name=self._embedding_model,
embedding_function=self._embedding_function,
max_tokens=self._chunk_token_size,
client=self._client,
collection_name=self._collection_name,
chunk_mode=self._chunk_mode,
must_break_at_empty_line=self._must_break_at_empty_line,
embedding_model=self._embedding_model,
get_or_create=self._get_or_create,
embedding_function=self._embedding_function,
custom_text_split_function=self.custom_text_split_function,
client=self._client,
custom_text_types=self._custom_text_types,
recursive=self._recursive,
)
self._collection = True
self._get_or_create = True

results = query_vector_db(
query_texts=[problem],
n_results=n_results,
search_string=search_string,
client=self._client,
collection_name=self._collection_name,
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
if self._db_mode:
if self._db_mode not in ["get", "recreate", "create"]:
raise ValueError(
f"db_mode {self._db_mode} is not supported. Possible values are 'get', 'recreate', 'create'."
)
if self._db_mode == "get":
if not self.retriever.index_exists:
raise ValueError("The index doesn't exist. Please set db_mode to 'recreate' or 'create'.")
self.retriever.use_existing_index()
elif self._db_mode == "recreate":
logger.info("Trying to create index. If the index already exists, it will be recreated.")
self.retriever.ingest_data(self._docs_path, overwrite=True)
elif self._db_mode == "create":
logger.info("Trying to create index.")
if self.retriever.index_exists:
raise ValueError("The index already exists. Please set db_mode to 'get' or 'recreate'.")
self.retriever.ingest_data(self._docs_path, overwrite=False)

elif self._get_or_create is not None:
if self._get_or_create and self.retriever.index_exists:
logger.info("Trying to use existing collection.")
self.retriever.use_existing_index()
else:
logger.info("Trying to create index.")
self.retriever.ingest_data(self._docs_path, overwrite=False)

results = self.retriever.query(
texts=[problem],
top_k=n_results,
filter=search_string,
)
self._search_string = search_string
self._results = results
Expand Down
1 change: 1 addition & 0 deletions autogen/agentchat/contrib/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .retrieve_utils import get_retriever
91 changes: 91 additions & 0 deletions autogen/agentchat/contrib/retriever/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import List, Union, Callable, Any


class Retriever(ABC):
def __init__(
self,
path="./db",
name="vectorstore",
embedding_model_name="all-MiniLM-L6-v2",
embedding_function=None,
max_tokens: int = 4000,
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
client=None,
# TODO: add support for custom text types and recurisive
custom_text_types: str = None,
recursive: bool = True,
):
"""
Args:
path: path to the folder where the database is stored
name: name of the database
embedding_model_name: name of the embedding model to use
embedding_function: function to use to embed the text
max_tokens: maximum number of tokens to embed
chunk_mode: mode to chunk the text. Can be "multi_lines" or "single_line"
must_break_at_empty_line: chunk will only break at empty line if True. Default is True.
If chunk_mode is "one_line", this parameter will be ignored.
custom_text_split_function: custom function to split the text into chunks
AyushExel marked this conversation as resolved.
Show resolved Hide resolved
client: client to use to connect to the database
custom_text_types: custom text types to ingest
recursive: whether to recursively ingest the files in the directory
"""
self.path = path
self.name = name
self.embedding_model_name = embedding_model_name
self.embedding_function = embedding_function
self.max_tokens = max_tokens
self.chunk_mode = chunk_mode
self.must_break_at_empty_line = must_break_at_empty_line
self.custom_text_split_function = custom_text_split_function
self.client = client
self.custom_text_types = custom_text_types
self.recursive = recursive

self.init_db()

@abstractmethod
def ingest_data(self, data_dir, overwrite: bool = False):
"""
Create a vector database from a directory of files.
Args:
data_dir: path to the directory containing the text files
overwrite: overwrite the existing database if True
"""
pass

@abstractmethod
def use_existing_index(self):
"""
Open an existing index.
"""
pass

@abstractmethod
def query(self, texts: List[str], top_k: int = 10, search_string: Any = None):
"""
Query the database.
Args:
texts: list of texts to query
top_k: number of results to return
search_string: string to filter the results
"""
pass

@abstractmethod
def init_db(self):
"""
Initialize the database.
"""
pass

@property
@abstractmethod
def index_exists(self):
"""
Check if the index exists in the database.
"""
pass
Loading
Loading