Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support no chromadb, sentence_transformers or pypdf #556

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
import re

try:
import chromadb
except ImportError:
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
from typing import Callable, Dict, Optional, Union, List, Tuple, Any
from IPython import get_ipython
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
from autogen.token_count_utils import count_token
from autogen.code_utils import extract_code

from typing import Callable, Dict, Optional, Union, List, Tuple, Any
from IPython import get_ipython

try:
from termcolor import colored
except ImportError:
Expand All @@ -21,6 +15,15 @@ def colored(x, *args, **kwargs):
return x


try:
import chromadb
except ImportError:

class chromadb:
class Client:
pass


PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
context provided by the user. You should follow the following steps to answer a question:
Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or
Expand Down
36 changes: 27 additions & 9 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,30 @@
import requests
from urllib.parse import urlparse
import glob
import chromadb

if chromadb.__version__ < "0.4.15":
from chromadb.api import API
else:
from chromadb.api import ClientAPI as API
from chromadb.api.types import QueryResult
import chromadb.utils.embedding_functions as ef
import logging
import pypdf
from autogen.token_count_utils import count_token

try:
import chromadb

if chromadb.__version__ < "0.4.15":
from chromadb.api import API
else:
from chromadb.api import ClientAPI as API
from chromadb.api.types import QueryResult
import chromadb.utils.embedding_functions as ef

HAS_CHROMADB = True
except ImportError:
HAS_CHROMADB = False

class API:
pass

class QueryResult(dict):
pass


logger = logging.getLogger(__name__)
TEXT_FORMATS = [
"txt",
Expand Down Expand Up @@ -88,6 +100,8 @@ def split_text_to_chunks(

def extract_text_from_pdf(file: str) -> str:
"""Extract text from PDF files"""
import pypdf

text = ""
with open(file, "rb") as f:
reader = pypdf.PdfReader(f)
Expand Down Expand Up @@ -240,6 +254,8 @@ def create_vector_db_from_dir(
Returns:
API: the chromadb client.
"""
if not HAS_CHROMADB:
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make the error a constant in the header.

if client is None:
client = chromadb.PersistentClient(path=db_path)
try:
Expand Down Expand Up @@ -314,6 +330,8 @@ class QueryResult(TypedDict):
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
"""
if not HAS_CHROMADB:
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
if client is None:
client = chromadb.PersistentClient(path=db_path)
# the collection's embedding function is always the default one, but we want to use the one we used to create the
Expand Down
26 changes: 26 additions & 0 deletions test/agentchat/contrib/test_qdrant_retrievechat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@
reason="do not run on MacOS or windows or dependency is not installed",
)
def test_retrievechat():
try:
# uninstall chromadb first
import chromadb

HAS_CHROMADB = True
os.system("pip uninstall -yq chromadb")
except ImportError:
HAS_CHROMADB = False

conversations = {}
# ChatCompletion.start_logging(conversations) # deprecated in v0.2

Expand Down Expand Up @@ -71,9 +80,22 @@ def test_retrievechat():
ragproxyagent.initiate_chat(assistant, problem=code_problem, silent=True)
print(conversations)

# reinstall chromadb
if HAS_CHROMADB:
os.system("pip install -q chromadb")


@pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed")
def test_qdrant_filter():
try:
# uninstall chromadb first
import chromadb

HAS_CHROMADB = True
os.system("pip uninstall -yq chromadb")
except ImportError:
HAS_CHROMADB = False

client = QdrantClient(":memory:")
create_qdrant_from_dir(dir_path="./website/docs", client=client, collection_name="autogen-docs")
results = query_qdrant(
Expand All @@ -86,6 +108,10 @@ def test_qdrant_filter():
)
assert len(results["ids"][0]) == 4

# reinstall chromadb
if HAS_CHROMADB:
os.system("pip install -q chromadb")


@pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed")
def test_qdrant_search():
Expand Down
Loading