diff --git a/metagpt/rag/parser/omniparse/client.py b/metagpt/rag/parser/omniparse/client.py
index 20cc70c4d..7386bff0d 100644
--- a/metagpt/rag/parser/omniparse/client.py
+++ b/metagpt/rag/parser/omniparse/client.py
@@ -1,10 +1,10 @@
 import mimetypes
 import os
 from pathlib import Path
+from typing import Union
 
 import aiofiles
 import httpx
-from typing import Union
 
 from metagpt.rag.schema import OmniParsedResult
 
@@ -14,6 +14,7 @@ class OmniParseClient:
     OmniParse Server Client
     OmniParse API Docs: https://docs.cognitivelab.in/api
     """
+
     ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
     ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
     ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
@@ -33,11 +34,16 @@ def __init__(self, api_key=None, base_url="http://localhost:8000", max_timeout=120):
         self.parse_website_endpoint = "/parse_website"
         self.parse_document_endpoint = "/parse_document"
 
-    async def __request_parse(
-        self, endpoint: str, method: str = "POST",
-        files: dict = None, params: dict = None,
-        data: dict = None, json: dict = None,
-        headers: dict = None, **kwargs,
+    async def _request_parse(
+        self,
+        endpoint: str,
+        method: str = "POST",
+        files: dict = None,
+        params: dict = None,
+        data: dict = None,
+        json: dict = None,
+        headers: dict = None,
+        **kwargs,
     ) -> dict:
         """
         Request the API to parse a document.
@@ -61,9 +67,15 @@ async def __request_parse(
         headers.update(**_headers)
         async with httpx.AsyncClient() as client:
             response = await client.request(
-                url=url, method=method,
-                files=files, params=params, json=json, data=data,
-                headers=headers, timeout=self.max_timeout, **kwargs,
+                url=url,
+                method=method,
+                files=files,
+                params=params,
+                json=json,
+                data=data,
+                headers=headers,
+                timeout=self.max_timeout,
+                **kwargs,
             )
             response.raise_for_status()
             return response.json()
@@ -98,9 +110,9 @@ def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
 
     @staticmethod
     async def get_file_info(
-            filelike: Union[str, bytes, Path],
-            bytes_filename: str = None,
-            only_bytes=True,
+        filelike: Union[str, bytes, Path],
+        bytes_filename: str = None,
+        only_bytes=True,
     ) -> Union[bytes, tuple]:
         """
         Get the file's byte information.
@@ -120,7 +132,7 @@ async def get_file_info(
         """
         if isinstance(filelike, (str, Path)):
             filename = os.path.basename(str(filelike))
-            async with aiofiles.open(filelike, 'rb') as file:
+            async with aiofiles.open(filelike, "rb") as file:
                 file_bytes = await file.read()
 
             if only_bytes:
@@ -154,7 +166,7 @@ async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
         """
         self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
         file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
-        resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info})
+        resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
         data = OmniParsedResult(**resp)
         return data
 
@@ -173,7 +185,7 @@ async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
         self.verify_file_ext(filelike, {".pdf"})
         file_info = await self.get_file_info(filelike)
         endpoint = f"{self.parse_document_endpoint}/pdf"
-        resp = await self.__request_parse(endpoint=endpoint, files={'file': file_info})
+        resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
         data = OmniParsedResult(**resp)
         return data
 
@@ -181,17 +193,17 @@ async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
         """Parse video."""
         self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
         file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
-        return await self.__request_parse(f"{self.parse_media_endpoint}/video", files={'file': file_info})
+        return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
 
     async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
         """Parse audio."""
         self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
         file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
-        return await self.__request_parse(f"{self.parse_media_endpoint}/audio", files={'file': file_info})
+        return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
 
     async def parse_website(self, url: str) -> dict:
         """
         Parse a website. FIXME: the official API still has issues.
         """
-        return await self.__request_parse(f"{self.parse_website_endpoint}/parse", params={'url': url})
+        return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})
diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py
index 24c597196..7f34a0be9 100644
--- a/metagpt/rag/schema.py
+++ b/metagpt/rag/schema.py
@@ -1,9 +1,9 @@
 """RAG schemas."""
 from enum import Enum
 from pathlib import Path
-from typing import Any, ClassVar, Literal, Optional, Union, List
+from typing import Any, ClassVar, List, Literal, Optional, Union
 
-# from chromadb.api.types import CollectionMetadata
+from chromadb.api.types import CollectionMetadata
 from llama_index.core.embeddings import BaseEmbedding
 from llama_index.core.indices.base import BaseIndex
 from llama_index.core.schema import TextNode
@@ -68,9 +68,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
 
     persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
     collection_name: str = Field(default="metagpt", description="The name of the collection.")
-    # metadata: Optional[CollectionMetadata] = Field(
-    #     default=None, description="Optional metadata to associate with the collection"
-    # )
+    metadata: Optional[CollectionMetadata] = Field(
+        default=None, description="Optional metadata to associate with the collection"
+    )
 
 
 class ElasticsearchStoreConfig(BaseModel):
@@ -166,9 +166,9 @@ class ChromaIndexConfig(VectorIndexConfig):
     """Config for chroma-based index."""
 
     collection_name: str = Field(default="metagpt", description="The name of the collection.")
-    # metadata: Optional[CollectionMetadata] = Field(
-    #     default=None, description="Optional metadata to associate with the collection"
-    # )
+    metadata: Optional[CollectionMetadata] = Field(
+        default=None, description="Optional metadata to associate with the collection"
+    )
 
 
 class BM25IndexConfig(BaseIndexConfig):
@@ -219,12 +219,14 @@ def get_obj_metadata(obj: RAGObject) -> dict:
 
 class OmniParseType(str, Enum):
     """OmniParse parse type."""
+
     PDF = "PDF"
     DOCUMENT = "DOCUMENT"
 
 
 class OmniParseOptions(BaseModel):
     """Optional settings for OmniParse."""
+
     result_type: ResultType = Field(default=ResultType.MD, description="The result type returned by OmniParse")
     parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse type; defaults to DOCUMENT")
     max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py
index 8c7a15be2..61b9816a5 100644
--- a/tests/metagpt/rag/engines/test_simple.py
+++ b/tests/metagpt/rag/engines/test_simple.py
@@ -37,6 +37,10 @@ def mock_get_rankers(self, mocker):
     def mock_get_response_synthesizer(self, mocker):
         return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
 
+    @pytest.fixture
+    def mock_get_file_extractor(self, mocker):
+        return mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor")
+
     def test_from_docs(
         self,
         mocker,
@@ -44,6 +48,7 @@ def test_from_docs(
         mock_get_retriever,
         mock_get_rankers,
         mock_get_response_synthesizer,
+        mock_get_file_extractor,
     ):
         # Mock
         mock_simple_directory_reader.return_value.load_data.return_value = [
@@ -53,6 +58,8 @@ def test_from_docs(
         mock_get_retriever.return_value = mocker.MagicMock()
         mock_get_rankers.return_value = [mocker.MagicMock()]
         mock_get_response_synthesizer.return_value = mocker.MagicMock()
+        file_extractor = mocker.MagicMock()
+        mock_get_file_extractor.return_value = file_extractor
 
         # Setup
         input_dir = "test_dir"
@@ -75,7 +82,9 @@ def test_from_docs(
         )
 
         # Assert
-        mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
+        mock_simple_directory_reader.assert_called_once_with(
+            input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
+        )
         mock_get_retriever.assert_called_once()
         mock_get_rankers.assert_called_once()
         mock_get_response_synthesizer.assert_called_once_with(llm=llm)
diff --git a/tests/metagpt/rag/parser/test_omniparse.py b/tests/metagpt/rag/parser/test_omniparse.py
new file mode 100644
index 000000000..79b173d5b
--- /dev/null
+++ b/tests/metagpt/rag/parser/test_omniparse.py
@@ -0,0 +1,32 @@
+import pytest
+
+from metagpt.const import EXAMPLE_DATA_PATH
+from metagpt.rag.parser.omniparse import OmniParseClient
+from metagpt.rag.schema import OmniParsedResult
+
+
+class TestOmniParseClient:
+    parse_client = OmniParseClient()
+
+    # test data
+    TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
+    TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
+    TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
+    TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"
+
+    @pytest.fixture
+    def request_parse(self, mocker):
+        return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
+
+    @pytest.mark.asyncio
+    async def test_parse_pdf(self, request_parse):
+        mock_content = "#test title\ntest content"
+        mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
+        request_parse.return_value = mock_parsed_ret.model_dump()
+        parse_ret = await self.parse_client.parse_pdf(self.TEST_PDF)
+        assert parse_ret == mock_parsed_ret
+
+
+class TestOmniParse:
+    def test_load_data(self):
+        pass
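
Note (not part of the diff): a minimal usage sketch of the refactored client, assuming an OmniParse server is reachable at the default base_url ("http://localhost:8000") and using a hypothetical local file path ("docs/sample.pdf").

# Hypothetical usage sketch; server URL and file path are assumptions.
import asyncio

from metagpt.rag.parser.omniparse import OmniParseClient
from metagpt.rag.schema import OmniParsedResult


async def main() -> None:
    client = OmniParseClient(base_url="http://localhost:8000", max_timeout=120)
    # parse_pdf() validates the ".pdf" extension, reads the file bytes, and posts
    # them to the /parse_document/pdf endpoint via the now-mockable _request_parse.
    result: OmniParsedResult = await client.parse_pdf("docs/sample.pdf")
    print(result.markdown[:200])


if __name__ == "__main__":
    asyncio.run(main())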