diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index b24249bbe96..a9e85b55963 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -1,18 +1,12 @@ import re - -try: - import chromadb -except ImportError: - raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`") +from typing import Callable, Dict, Optional, Union, List, Tuple, Any +from IPython import get_ipython from autogen.agentchat.agent import Agent from autogen.agentchat import UserProxyAgent from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db from autogen.token_count_utils import count_token from autogen.code_utils import extract_code -from typing import Callable, Dict, Optional, Union, List, Tuple, Any -from IPython import get_ipython - try: from termcolor import colored except ImportError: @@ -21,6 +15,15 @@ def colored(x, *args, **kwargs): return x +try: + import chromadb +except ImportError: + + class chromadb: + class Client: + pass + + PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. You should follow the following steps to answer a question: Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index bc4fdfb7597..42b264b0296 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -3,18 +3,30 @@ import requests from urllib.parse import urlparse import glob -import chromadb - -if chromadb.__version__ < "0.4.15": - from chromadb.api import API -else: - from chromadb.api import ClientAPI as API -from chromadb.api.types import QueryResult -import chromadb.utils.embedding_functions as ef import logging -import pypdf from autogen.token_count_utils import count_token +try: + import chromadb + + if chromadb.__version__ < "0.4.15": + from chromadb.api import API + else: + from chromadb.api import ClientAPI as API + from chromadb.api.types import QueryResult + import chromadb.utils.embedding_functions as ef + + HAS_CHROMADB = True +except ImportError: + HAS_CHROMADB = False + + class API: + pass + + class QueryResult(dict): + pass + + logger = logging.getLogger(__name__) TEXT_FORMATS = [ "txt", @@ -88,6 +100,8 @@ def split_text_to_chunks( def extract_text_from_pdf(file: str) -> str: """Extract text from PDF files""" + import pypdf + text = "" with open(file, "rb") as f: reader = pypdf.PdfReader(f) @@ -240,6 +254,8 @@ def create_vector_db_from_dir( Returns: API: the chromadb client. """ + if not HAS_CHROMADB: + raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`") if client is None: client = chromadb.PersistentClient(path=db_path) try: @@ -314,6 +330,8 @@ class QueryResult(TypedDict): metadatas: Optional[List[List[Metadata]]] distances: Optional[List[List[float]]] """ + if not HAS_CHROMADB: + raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`") if client is None: client = chromadb.PersistentClient(path=db_path) # the collection's embedding function is always the default one, but we want to use the one we used to create the diff --git a/test/agentchat/contrib/test_qdrant_retrievechat.py b/test/agentchat/contrib/test_qdrant_retrievechat.py index 1d3c5afd6af..8f9548887e5 100644 --- a/test/agentchat/contrib/test_qdrant_retrievechat.py +++ b/test/agentchat/contrib/test_qdrant_retrievechat.py @@ -35,6 +35,15 @@ reason="do not run on MacOS or windows or dependency is not installed", ) def test_retrievechat(): + try: + # uninstall chromadb first + import chromadb + + HAS_CHROMADB = True + os.system("pip uninstall -yq chromadb") + except ImportError: + HAS_CHROMADB = False + conversations = {} # ChatCompletion.start_logging(conversations) # deprecated in v0.2 @@ -71,9 +80,22 @@ def test_retrievechat(): ragproxyagent.initiate_chat(assistant, problem=code_problem, silent=True) print(conversations) + # reinstall chromadb + if HAS_CHROMADB: + os.system("pip install -q chromadb") + @pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed") def test_qdrant_filter(): + try: + # uninstall chromadb first + import chromadb + + HAS_CHROMADB = True + os.system("pip uninstall -yq chromadb") + except ImportError: + HAS_CHROMADB = False + client = QdrantClient(":memory:") create_qdrant_from_dir(dir_path="./website/docs", client=client, collection_name="autogen-docs") results = query_qdrant( @@ -86,6 +108,10 @@ def test_qdrant_filter(): ) assert len(results["ids"][0]) == 4 + # reinstall chromadb + if HAS_CHROMADB: + os.system("pip install -q chromadb") + @pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed") def test_qdrant_search():