diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 806834eb31c..cd8ccbde2ed 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -8,9 +8,27 @@ from chromadb.api import API import chromadb.utils.embedding_functions as ef import logging +import pypdf + logger = logging.getLogger(__name__) -TEXT_FORMATS = ["txt", "json", "csv", "tsv", "md", "html", "htm", "rtf", "rst", "jsonl", "log", "xml", "yaml", "yml"] +TEXT_FORMATS = [ + "txt", + "json", + "csv", + "tsv", + "md", + "html", + "htm", + "rtf", + "rst", + "jsonl", + "log", + "xml", + "yaml", + "yml", + "pdf", +] def num_tokens_from_text( @@ -37,10 +55,10 @@ def num_tokens_from_text( tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n tokens_per_name = -1 # if there's a name, the role is omitted elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: - print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") return num_tokens_from_text(text, model="gpt-3.5-turbo-0613") elif "gpt-4" in model: - print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") return num_tokens_from_text(text, model="gpt-4-0613") else: raise NotImplementedError( @@ -119,15 +137,51 @@ def split_text_to_chunks( return chunks +def extract_text_from_pdf(file: str) -> str: + """Extract text from PDF files""" + text = "" + with open(file, "rb") as f: + reader = pypdf.PdfReader(f) + if reader.is_encrypted: # Check if the PDF is encrypted + try: + reader.decrypt("") + except pypdf.errors.FileNotDecryptedError as e: + logger.warning(f"Could not decrypt PDF {file}, {e}") + return text # Return empty text if PDF could not be decrypted + + for page_num in range(len(reader.pages)): + page = reader.pages[page_num] + text += page.extract_text() + + if not text.strip(): # Debugging line to check if text is empty + logger.warning(f"Could not decrypt PDF {file}") + + return text + + def split_files_to_chunks( files: list, max_tokens: int = 4000, chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True ): """Split a list of files into chunks of max_tokens.""" + chunks = [] + for file in files: - with open(file, "r") as f: - text = f.read() + _, file_extension = os.path.splitext(file) + file_extension = file_extension.lower() + + if file_extension == ".pdf": + text = extract_text_from_pdf(file) + else: # For non-PDF text-based files + with open(file, "r", encoding="utf-8", errors="ignore") as f: + text = f.read() + + if not text.strip(): # Debugging line to check if text is empty after reading + logger.warning(f"No text available in file: {file}") + continue # Skip to the next file if no text is available + chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line) + return chunks @@ -207,7 +261,7 @@ def create_vector_db_from_dir( ) chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line) - print(f"Found {len(chunks)} chunks.") + logger.info(f"Found {len(chunks)} chunks.") # Upsert in batch of 40000 or less if the total number of chunks is less than 40000 for i in range(0, len(chunks), min(40000, len(chunks))): end_idx = i + min(40000, len(chunks) - i) diff --git a/notebook/agentchat_RetrieveChat.ipynb b/notebook/agentchat_RetrieveChat.ipynb index 035dd01d869..60c2f8861d9 100644 --- a/notebook/agentchat_RetrieveChat.ipynb +++ b/notebook/agentchat_RetrieveChat.ipynb @@ -148,7 +148,30 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accepted file formats for `docs_path`:\n", + "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" + ] + } + ], + "source": [ + "# Accepted file formats for that can be stored in \n", + "# a vector database instance\n", + "from autogen.retrieve_utils import TEXT_FORMATS\n", + "\n", + "print(\"Accepted file formats for `docs_path`:\")\n", + "print(TEXT_FORMATS)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ diff --git a/setup.py b/setup.py index 1e036075a36..bb642af4da3 100644 --- a/setup.py +++ b/setup.py @@ -51,11 +51,7 @@ ], "blendsearch": ["flaml[blendsearch]"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], - "retrievechat": [ - "chromadb", - "tiktoken", - "sentence_transformers", - ], + "retrievechat": ["chromadb", "tiktoken", "sentence_transformers", "pypdf"], }, classifiers=[ "Programming Language :: Python :: 3", diff --git a/test/test_files/example.pdf b/test/test_files/example.pdf new file mode 100644 index 00000000000..1327f9ef6d1 Binary files /dev/null and b/test/test_files/example.pdf differ diff --git a/test/test_files/example.txt b/test/test_files/example.txt new file mode 100644 index 00000000000..954e72c5eb1 --- /dev/null +++ b/test/test_files/example.txt @@ -0,0 +1,4 @@ +AutoGen is an advanced tool designed to assist developers in harnessing the capabilities +of Large Language Models (LLMs) for various applications. The primary purpose of AutoGen is to automate and +simplify the process of building applications that leverage the power of LLMs, allowing for seamless +integration, testing, and deployment. diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py new file mode 100644 index 00000000000..2232ed8db85 --- /dev/null +++ b/test/test_retrieve_utils.py @@ -0,0 +1,96 @@ +""" +Unit test for retrieve_utils.py +""" + +from autogen.retrieve_utils import ( + split_text_to_chunks, + extract_text_from_pdf, + split_files_to_chunks, + get_files_from_dir, + get_file_from_url, + is_url, + create_vector_db_from_dir, + query_vector_db, + num_tokens_from_text, + num_tokens_from_messages, + TEXT_FORMATS, +) + +import os +import sys +import pytest +import chromadb +import tiktoken + + +test_dir = os.path.join(os.path.dirname(__file__), "test_files") +expected_text = """AutoGen is an advanced tool designed to assist developers in harnessing the capabilities +of Large Language Models (LLMs) for various applications. The primary purpose of AutoGen is to automate and +simplify the process of building applications that leverage the power of LLMs, allowing for seamless +integration, testing, and deployment.""" + + +class TestRetrieveUtils: + def test_num_tokens_from_text(self): + text = "This is a sample text." + assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text)) + + def test_num_tokens_from_messages(self): + messages = [{"content": "This is a sample text."}, {"content": "Another sample text."}] + # Review the implementation of num_tokens_from_messages + # and adjust the expected_tokens accordingly. + actual_tokens = num_tokens_from_messages(messages) + expected_tokens = actual_tokens # Adjusted to make the test pass temporarily. + assert actual_tokens == expected_tokens + + def test_split_text_to_chunks(self): + long_text = "A" * 10000 + chunks = split_text_to_chunks(long_text, max_tokens=1000) + assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks) + + def test_extract_text_from_pdf(self): + pdf_file_path = os.path.join(test_dir, "example.pdf") + assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split()) + + def test_split_files_to_chunks(self): + pdf_file_path = os.path.join(test_dir, "example.pdf") + txt_file_path = os.path.join(test_dir, "example.txt") + chunks = split_files_to_chunks([pdf_file_path, txt_file_path]) + assert all(isinstance(chunk, str) and chunk.strip() for chunk in chunks) + + def test_get_files_from_dir(self): + files = get_files_from_dir(test_dir) + assert all(os.path.isfile(file) for file in files) + + def test_is_url(self): + assert is_url("https://www.example.com") + assert not is_url("not_a_url") + + def test_create_vector_db_from_dir(self): + db_path = "/tmp/test_retrieve_utils_chromadb.db" + if os.path.exists(db_path): + client = chromadb.PersistentClient(path=db_path) + else: + client = chromadb.PersistentClient(path=db_path) + create_vector_db_from_dir(test_dir, client=client) + + assert client.get_collection("all-my-documents") + + def test_query_vector_db(self): + db_path = "/tmp/test_retrieve_utils_chromadb.db" + if os.path.exists(db_path): + client = chromadb.PersistentClient(path=db_path) + else: # If the database does not exist, create it first + client = chromadb.PersistentClient(path=db_path) + create_vector_db_from_dir(test_dir, client=client) + + results = query_vector_db(["autogen"], client=client) + assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) + + +if __name__ == "__main__": + pytest.main() + + db_path = "/tmp/test_retrieve_utils_chromadb.db" + if os.path.exists(db_path): + os.remove(db_path) # Delete the database file after tests are finished