New Retriever API #619
Changes from 25 commits
Changed file: the `RetrieveUserProxyAgent` module.

```diff
@@ -1,14 +1,11 @@
 import re

-try:
-    import chromadb
-except ImportError:
-    raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
+from autogen.agentchat.contrib.retriever.retrieve_utils import TEXT_FORMATS
 from autogen.token_count_utils import count_token
 from autogen.code_utils import extract_code
+from autogen.agentchat.contrib.retriever import get_retriever

 from typing import Callable, Dict, Optional, Union, List, Tuple, Any
 from IPython import get_ipython
```
```diff
@@ -95,10 +92,9 @@ def __init__(
             To use default config, set to None. Otherwise, set to a dictionary with the following keys:
             - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
                 prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
-            - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
-                will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
-            - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
-                the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
+            - client (Optional, Any): the vectordb client/connection. If key not provided, the Retriever class should handle it.
+            - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
+                or the url to a single file. Default is None, which works only if the collection is already created.
             - collection_name (Optional, str): the name of the collection.
                 If key not provided, a default name `autogen-docs` will be used.
             - model (Optional, str): the model to use for the retrieve chat.
```
```diff
@@ -122,9 +118,9 @@ def __init__(
             - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                 If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
             - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
-            - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb.
-                Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
-            - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
+            - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
+                This is the same as that used in retriever. Default is False. Will be set to False if docs_path is None.
+            - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                 The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
                 Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
             - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
```

Review comments on `get_or_create`:
- "I suggest we use a new parameter such as …"
- "should we use …"
- "Both are OK to me. I suggested …"
- "@thinkall I've actually changed the logic of creating and querying retrievers and I think this change probably won't be required."
- "I think your new design still misses …"
````diff
@@ -135,7 +131,7 @@ def __init__(
         **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).

     Example of overriding retrieve_docs:
-    If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug it in with the code below.
+    If you have set up a customized vector db, and it's not compatible with retriever, you can easily plug it in with the code below.
     ```python
     class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
         def query_vector_db(
````

Review comments:
- "Is this sample code still correct?"
- "Yes, it should still be supported. That's what qdrant's file is doing, right?"
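The override pattern discussed above can be sketched in a self-contained form. The `RetrieveUserProxyAgent` base class here is a minimal stand-in stub, not the real autogen class, and the in-memory keyword match stands in for a real vector search; only the result shape (`"ids"` and `"documents"` keys) follows the documented contract.

```python
from typing import Dict, List


class RetrieveUserProxyAgent:  # stand-in stub for autogen's agent, for illustration only
    def query_vector_db(self, query_texts: List[str], n_results: int = 10, **kwargs) -> Dict:
        raise NotImplementedError


class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
    """Plugs a custom (here: in-memory) store into the retrieve flow."""

    def __init__(self, docs: List[str]):
        self._docs = docs

    def query_vector_db(self, query_texts: List[str], n_results: int = 10, **kwargs) -> Dict:
        # Naive keyword match standing in for a real vector search; the result keeps
        # the documented shape: one inner list per query under "ids" and "documents".
        hits = [d for d in self._docs if any(q.lower() in d.lower() for q in query_texts)]
        hits = hits[:n_results]
        return {
            "ids": [[f"doc_{i}" for i, _ in enumerate(hits)]],
            "documents": [hits],
        }


agent = MyRetrieveUserProxyAgent(["AutoGen supports RAG.", "Unrelated text."])
result = agent.query_vector_db(["rag"], n_results=5)
```

The key point of the pattern is that only the query hook changes; the rest of the agent's chat loop consumes the same result dictionary.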
```diff
@@ -165,10 +161,12 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             human_input_mode=human_input_mode,
             **kwargs,
         )

+        self.retriever = None
         self._retrieve_config = {} if retrieve_config is None else retrieve_config
+        self._retriever_type = self._retrieve_config.get("retriever_type")
+        self._retriever_path = self._retrieve_config.get("retriever_path", "~/autogen")
         self._task = self._retrieve_config.get("task", "default")
-        self._client = self._retrieve_config.get("client", chromadb.Client())
+        self._client = self._retrieve_config.get("client", None)
         self._docs_path = self._retrieve_config.get("docs_path", None)
         self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
         self._model = self._retrieve_config.get("model", "gpt-4")
```
```diff
@@ -355,13 +353,9 @@ def _generate_retrieve_user_reply(

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
         """Retrieve docs based on the given problem and assign the results to the class property `_results`.
         In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
         compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
         parameters and add your own parameters with default values, and keep the results in below type.

         Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
-        the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
-        to `chromadb.api.types.QueryResult` as an example.
+        the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional.
+        ids: List[string]
+        documents: List[List[string]]
```
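The result contract described in that docstring can be pinned down with a small validator. The sample values below are illustrative only; the key names and nesting follow the docstring's `ids`/`documents` description.

```python
from typing import Any, Dict


def validate_results(results: Dict[str, Any]) -> bool:
    """Check the minimal retrieve_docs result contract: 'ids' and 'documents'
    present, with 'documents' a list of lists of strings (one inner list per query)."""
    if "ids" not in results or "documents" not in results:
        return False
    docs = results["documents"]
    return isinstance(docs, list) and all(
        isinstance(batch, list) and all(isinstance(d, str) for d in batch) for batch in docs
    )


sample = {
    "ids": [["doc_0", "doc_1"]],
    "documents": [["first chunk", "second chunk"]],
    "distances": [[0.12, 0.34]],  # extra keys are allowed by the contract
}
ok = validate_results(sample)
```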
```diff
@@ -370,33 +364,35 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             n_results (int): the number of results to be retrieved.
             search_string (str): only docs containing this string will be retrieved.
         """
-        if not self._collection or not self._get_or_create:
-            print("Trying to create collection.")
-            self._client = create_vector_db_from_dir(
-                dir_path=self._docs_path,
-                max_tokens=self._chunk_token_size,
-                client=self._client,
-                collection_name=self._collection_name,
-                chunk_mode=self._chunk_mode,
-                must_break_at_empty_line=self._must_break_at_empty_line,
-                embedding_model=self._embedding_model,
-                get_or_create=self._get_or_create,
-                embedding_function=self._embedding_function,
-                custom_text_split_function=self.custom_text_split_function,
-                custom_text_types=self._custom_text_types,
-                recursive=self._recursive,
-            )
-            self._collection = True
-            self._get_or_create = True
-
-        results = query_vector_db(
-            query_texts=[problem],
-            n_results=n_results,
-            search_string=search_string,
-            client=self._client,
-            collection_name=self._collection_name,
-            embedding_model=self._embedding_model,
-            embedding_function=self._embedding_function,
-        )
+        if not self.retriever:
+            retriever_class = get_retriever(self._retriever_type)
+            self.retriever = retriever_class(
+                path=self._retriever_path,
+                name=self._collection_name,
+                embedding_model_name=self._embedding_model,
+                embedding_function=self._embedding_function,
+                max_tokens=self._chunk_token_size,
+                chunk_mode=self._chunk_mode,
+                must_break_at_empty_line=self._must_break_at_empty_line,
+                custom_text_split_function=self.custom_text_split_function,
+                client=self._client,
+                custom_text_types=self._custom_text_types,
+                recursive=self._recursive,
+            )
+
+        if not self.retriever.index_exists() or not self._get_or_create:
+            print("Trying to create index.")  # TODO: logger
+            self.retriever.ingest_data(self._docs_path)
+        elif self._get_or_create:
+            if self.retriever.index_exists():
+                print("Trying to use existing collection.")  # TODO: logger
+                self.retriever.use_existing_index()
+            else:
+                raise Exception("Requested to use existing index but it is not found!")
+
+        results = self.retriever.query(
+            texts=[problem],
+            top_k=n_results,
+            filter=search_string,
+        )
         self._search_string = search_string
         self._results = results
```
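Under the new API, a `retrieve_config` might be assembled as below. The key names follow the `__init__` and `retrieve_docs` changes in this diff; the path, docs location, and model values are placeholders, not defaults confirmed by the PR.

```python
# Hypothetical configuration for the retriever-backed agent; only keys read by
# the new __init__ are shown, with placeholder values where the PR gives none.
retrieve_config = {
    "retriever_type": "chromadb",    # or "lancedb" (the package default)
    "retriever_path": "~/autogen",   # where the retriever persists its index
    "task": "qa",
    "docs_path": "./docs",           # directory, single file, or URL
    "collection_name": "autogen-docs",
    "model": "gpt-4",
    "get_or_create": False,          # False: (re)create the index from docs_path
}

# Mirrors how __init__ reads the config: .get with a None fallback for "client".
client = retrieve_config.get("client", None)
selected = {k: retrieve_config.get(k) for k in ("retriever_type", "task")}
```

Because every key is read with `.get`, any omitted key silently falls back to its default, which is why `client` can be left out entirely.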
New file (all 19 lines added): the retriever package entry point, registering the available backends.

```python
from typing import Optional

AVAILABLE_RETRIEVERS = ["lancedb", "chromadb"]

DEFAULT_RETRIEVER = "lancedb"


def get_retriever(type: Optional[str] = None):
    """Return a retriever class for the given type, defaulting to lancedb."""
    type = type or DEFAULT_RETRIEVER
    if type == "chromadb":
        from .chromadb import ChromaDB

        return ChromaDB
    elif type == "lancedb":
        from .lancedb import LanceDB

        return LanceDB
    else:
        raise ValueError(f"Unknown retriever type {type}")
```

Review comment: "Are you OK with it if we don't move qdrant here in this PR? @sonichi"
New file (all 82 lines added): the `Retriever` abstract base class.

```python
from abc import ABC, abstractmethod
from typing import List, Union, Callable, Any


class Retriever(ABC):
    def __init__(
        self,
        path="./db",
        name="vectorstore",
        embedding_model_name="all-MiniLM-L6-v2",
        embedding_function=None,
        max_tokens: int = 4000,
        chunk_mode: str = "multi_lines",
        must_break_at_empty_line: bool = True,
        custom_text_split_function: Callable = None,
        client=None,
        # TODO: add support for custom text types and recursive
        custom_text_types: str = None,
        recursive: bool = True,
    ):
        """
        Args:
            path: path to the folder where the database is stored
            name: name of the database
            embedding_model_name: name of the embedding model to use
            embedding_function: function to use to embed the text
            max_tokens: maximum number of tokens to embed
            chunk_mode: mode to chunk the text. Can be "multi_lines" or "single_line"
            must_break_at_empty_line: whether to break the text at empty lines when chunking
            custom_text_split_function: custom function to split the text into chunks
            client: an existing vectordb client/connection to reuse, if any
        """
        self.path = path
        self.name = name
        self.embedding_model_name = embedding_model_name
        self.embedding_function = embedding_function
        self.max_tokens = max_tokens
        self.chunk_mode = chunk_mode
        self.must_break_at_empty_line = must_break_at_empty_line
        self.custom_text_split_function = custom_text_split_function
        self.client = client

        self.init_db()

    @abstractmethod
    def ingest_data(self, data_dir):
        """
        Create a vector database from a directory of files.

        Args:
            data_dir: path to the directory containing the text files
        """
        pass

    @abstractmethod
    def use_existing_index(self):
        """
        Open an existing index.
        """
        pass

    @abstractmethod
    def query(self, texts: List[str], top_k: int = 10, filter: Any = None):
        """
        Query the database.

        Args:
            texts: list of query strings
            top_k: number of results to return per query
            filter: optional filter to restrict the search
        """
        pass

    @abstractmethod
    def init_db(self):
        """
        Initialize the database.
        """
        pass

    @abstractmethod
    def index_exists(self):
        """
        Check if the index exists in the database.
        """
        pass
```
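To make the abstract contract concrete, here is a minimal in-memory implementation of the same interface. The base class is re-declared locally (trimmed to its abstract methods) so the sketch runs without autogen installed; a real subclass would import `Retriever` from the package, and substring matching here stands in for embedding search.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class Retriever(ABC):  # local, trimmed re-declaration of the interface for this sketch
    def __init__(self, path="./db", name="vectorstore"):
        self.path = path
        self.name = name
        self.init_db()

    @abstractmethod
    def init_db(self): ...

    @abstractmethod
    def ingest_data(self, data_dir): ...

    @abstractmethod
    def use_existing_index(self): ...

    @abstractmethod
    def index_exists(self): ...

    @abstractmethod
    def query(self, texts: List[str], top_k: int = 10, filter: Any = None): ...


class InMemoryRetriever(Retriever):
    """Toy retriever: stores raw strings and matches by substring instead of embeddings."""

    def init_db(self):
        self._docs: List[str] = []

    def ingest_data(self, data_dir):
        # A real implementation would walk data_dir; here we accept a list of strings.
        self._docs = list(data_dir)

    def use_existing_index(self):
        pass  # nothing to reopen for an in-memory store

    def index_exists(self):
        return bool(self._docs)

    def query(self, texts: List[str], top_k: int = 10, filter: Any = None):
        batches = []
        for t in texts:
            hits = [d for d in self._docs if t.lower() in d.lower()]
            if filter:
                hits = [d for d in hits if filter in d]
            batches.append(hits[:top_k])
        # Same result shape the agent expects: "ids" and "documents" keys.
        return {"ids": [[f"doc_{i}" for i in range(len(b))] for b in batches], "documents": batches}


r = InMemoryRetriever()
r.ingest_data(["alpha beta", "beta gamma", "delta"])
res = r.query(["beta"], top_k=1)
```

Because `__init__` calls `init_db()`, a subclass only has to fill in the five abstract methods to participate in the retrieve chat flow.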
New file (all 83 lines added): the ChromaDB retriever implementation.

```python
from typing import List
from .base import Retriever
from .retrieve_utils import split_text_to_chunks, extract_text_from_pdf, split_files_to_chunks, get_files_from_dir

try:
    import chromadb

    if chromadb.__version__ < "0.4.15":
        from chromadb.api import API
    else:
        from chromadb.api import ClientAPI as API
    from chromadb.api.types import QueryResult
    import chromadb.utils.embedding_functions as ef
except ImportError:
    raise ImportError("Please install chromadb: `pip install chromadb`")


class ChromaDB(Retriever):
    def init_db(self):
        self.client = chromadb.PersistentClient(path=self.path)
        self.embedding_function = (
            ef.SentenceTransformerEmbeddingFunction(self.embedding_model_name)
            if self.embedding_function is None
            else self.embedding_function
        )
        self.collection = None

    def ingest_data(self, data_dir):
        """
        Create a vector database from a directory of files.

        Args:
            data_dir: path to the directory containing the text files
        """
        self.collection = self.client.create_collection(
            self.name,
            embedding_function=self.embedding_function,
            # https://github.com/nmslib/hnswlib#supported-distances
            # https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184
            # https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
            metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32},  # ip, l2, cosine
        )

        if self.custom_text_split_function is not None:
            chunks = split_files_to_chunks(
                get_files_from_dir(data_dir), custom_text_split_function=self.custom_text_split_function
            )
        else:
            chunks = split_files_to_chunks(
                get_files_from_dir(data_dir), self.max_tokens, self.chunk_mode, self.must_break_at_empty_line
            )
        print(f"Found {len(chunks)} chunks.")
        # Upsert in batches of 40000, or fewer if the total number of chunks is less than 40000
        for i in range(0, len(chunks), min(40000, len(chunks))):
            end_idx = i + min(40000, len(chunks) - i)
            self.collection.upsert(
                documents=chunks[i:end_idx],
                ids=[f"doc_{j}" for j in range(i, end_idx)],  # unique for each doc
            )

    def use_existing_index(self):
        self.collection = self.client.get_collection(name=self.name, embedding_function=self.embedding_function)

    def query(self, texts: List[str], top_k: int = 10, filter: str = None):
        # The collection's embedding function is always the default one, but we want to use the one
        # we used to create the collection, so we compute the embeddings ourselves and pass them in.
        query_embeddings = self.embedding_function(texts)
        # Query/search the n most similar results. You can also .get by id.
        results = self.collection.query(
            query_embeddings=query_embeddings,
            n_results=top_k,
            where_document={"$contains": filter} if filter else None,  # optional filter
        )
        return results

    def index_exists(self):
        try:
            self.client.get_collection(name=self.name, embedding_function=self.embedding_function)
            # Not sure if there's an explicit way to check if a collection exists for chromadb
            return True
        except Exception:
            return False
```
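The upsert loop in `ingest_data` walks the chunks in fixed-size batches. Its slicing arithmetic can be checked in isolation; the helper below reproduces the `(start, end)` pairs the loop visits, with the batch size shrunk from 40000 to 4 for the demonstration.

```python
from typing import List, Tuple


def batch_bounds(n_chunks: int, batch_size: int = 40000) -> List[Tuple[int, int]]:
    """Reproduce the (start, end) index pairs the upsert loop visits:
    step by min(batch_size, n_chunks); the final batch is clipped to n_chunks."""
    if n_chunks == 0:
        return []  # the original loop's range step would be 0 here, so guard explicitly
    step = min(batch_size, n_chunks)
    bounds = []
    for i in range(0, n_chunks, step):
        end_idx = i + min(batch_size, n_chunks - i)
        bounds.append((i, end_idx))
    return bounds


bounds = batch_bounds(10, batch_size=4)
```

Every index appears in exactly one batch, and `ids=[f"doc_{j}" for j in range(i, end_idx)]` therefore assigns each document a unique id across batches.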
Review comments:
- "qdrant is another vector db like lancedb and chromadb; would you like to remove the `QdrantRetrieveUserProxyAgent` and keep only `RetrieveUserProxyAgent`?"
- "Please add @Anush008 as a reviewer when the change involves qdrant in any PR."