Skip to content

Commit

Permalink
Merge branch 'main' into fix-teach
Browse files — browse the repository at this point in the history
  • Loading branch information
Hk669 authored Apr 17, 2024
2 parents fb00783 + c4e5703 commit a69d4df
Show file tree
Hide file tree
Showing 9 changed files with 1,178 additions and 1,053 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ test/agentchat/test_agent_scripts/*
# test cache
.cache_test
.db
local_cache


notebook/result.png
Expand Down
16 changes: 13 additions & 3 deletions autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import logging
from typing import Callable, Dict, List, Optional

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen.agentchat.contrib.vectordb.utils import (
chroma_results_to_query_results,
filter_results_by_distance,
get_logger,
)
from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks

logger = logging.getLogger(__name__)
logger = get_logger(__name__)

try:
import fastembed
from qdrant_client import QdrantClient, models
from qdrant_client.fastembed_common import QueryResponse
except ImportError as e:
logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
raise e


Expand Down Expand Up @@ -136,6 +140,11 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
collection_name=self._collection_name,
embedding_model=self._embedding_model,
)
results["contents"] = results.pop("documents")
results = chroma_results_to_query_results(results, "distances")
results = filter_results_by_distance(results, self._distance_threshold)

self._search_string = search_string
self._results = results


Expand Down Expand Up @@ -298,6 +307,7 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
"distances": [[result.score for result in sublist] for sublist in results],
"metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data
205 changes: 173 additions & 32 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/vectordb/chromadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ class ChromaVectorDB(VectorDB):
"""

def __init__(
self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs
self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
) -> None:
"""
Initialize the vector database.
Args:
client: chromadb.Client | The client object of the vector database. Default is None.
If provided, it will use the client object directly and ignore other arguments.
path: str | The path to the vector database. Default is None.
path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
Expand Down
3 changes: 3 additions & 0 deletions autogen/agentchat/contrib/vectordb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def error(self, msg, *args, color="light_red", **kwargs):
def critical(self, msg, *args, color="red", **kwargs):
    """Log *msg* at CRITICAL level, wrapped in *color* (default ``red``)."""
    tinted = colored(msg, color)
    super().critical(tinted, *args, **kwargs)

def fatal(self, msg, *args, color="red", **kwargs):
    """Log *msg* at FATAL severity (stdlib alias of CRITICAL), wrapped in *color*."""
    tinted = colored(msg, color)
    super().fatal(tinted, *args, **kwargs)


def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger:
logger = ColoredLogger(name, level)
Expand Down
24 changes: 16 additions & 8 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import glob
import hashlib
import os
import re
from typing import Callable, List, Tuple, Union
Expand Down Expand Up @@ -156,7 +157,7 @@ def split_files_to_chunks(
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
):
) -> Tuple[List[str], List[dict]]:
"""Split a list of files into chunks of max_tokens."""

chunks = []
Expand Down Expand Up @@ -275,15 +276,22 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
return webpage_text


def _generate_file_name_from_url(url: str, max_length=255) -> str:
url_bytes = url.encode("utf-8")
hash = hashlib.blake2b(url_bytes).hexdigest()
parsed_url = urlparse(url)
file_name = os.path.basename(url)
file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
return file_name


def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
"""Download a file from a URL."""
if save_path is None:
save_path = "tmp/chromadb"
os.makedirs(save_path, exist_ok=True)
if os.path.isdir(save_path):
filename = os.path.basename(url)
if filename == "": # "www.example.com/"
filename = url.split("/")[-2]
filename = _generate_file_name_from_url(url)
save_path = os.path.join(save_path, filename)
else:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
Expand Down Expand Up @@ -327,7 +335,7 @@ def create_vector_db_from_dir(
dir_path: Union[str, List[str]],
max_tokens: int = 4000,
client: API = None,
db_path: str = "/tmp/chromadb.db",
db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
get_or_create: bool = False,
chunk_mode: str = "multi_lines",
Expand All @@ -347,7 +355,7 @@ def create_vector_db_from_dir(
dir_path (Union[str, List[str]]): the path to the directory, file, url or a list of them.
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
client (Optional, API): the chromadb client. Default is None.
db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
db_path (Optional, str): the path to the chromadb. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False.
Expand Down Expand Up @@ -420,7 +428,7 @@ def query_vector_db(
query_texts: List[str],
n_results: int = 10,
client: API = None,
db_path: str = "/tmp/chromadb.db",
db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
Expand All @@ -433,7 +441,7 @@ def query_vector_db(
query_texts (List[str]): the list of strings which will be used to query the vector db.
n_results (Optional, int): the number of results to return. Default is 10.
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
db_path (Optional, str): the path to the vector db. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
search_string (Optional, str): the search string. Only docs that contain an exact match of this string will be retrieved. Default is "".
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
Expand Down
Loading

0 comments on commit a69d4df

Please sign in to comment.