Skip to content

Commit

Permalink
单测相关
Browse files Browse the repository at this point in the history
  • Loading branch information
HuiDBK committed Jul 19, 2024
1 parent 5287e02 commit 79334de
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 27 deletions.
48 changes: 30 additions & 18 deletions metagpt/rag/parser/omniparse/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import mimetypes
import os
from pathlib import Path
from typing import Union

import aiofiles
import httpx
from typing import Union

from metagpt.rag.schema import OmniParsedResult

Expand All @@ -14,6 +14,7 @@ class OmniParseClient:
OmniParse Server Client
OmniParse API Docs: https://docs.cognitivelab.in/api
"""

ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
Expand All @@ -33,11 +34,16 @@ def __init__(self, api_key=None, base_url="http://localhost:8000", max_timeout=1
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"

async def __request_parse(
self, endpoint: str, method: str = "POST",
files: dict = None, params: dict = None,
data: dict = None, json: dict = None,
headers: dict = None, **kwargs,
async def _request_parse(
self,
endpoint: str,
method: str = "POST",
files: dict = None,
params: dict = None,
data: dict = None,
json: dict = None,
headers: dict = None,
**kwargs,
) -> dict:
"""
请求api解析文档
Expand All @@ -61,9 +67,15 @@ async def __request_parse(
headers.update(**_headers)
async with httpx.AsyncClient() as client:
response = await client.request(
url=url, method=method,
files=files, params=params, json=json, data=data,
headers=headers, timeout=self.max_timeout, **kwargs,
url=url,
method=method,
files=files,
params=params,
json=json,
data=data,
headers=headers,
timeout=self.max_timeout,
**kwargs,
)
response.raise_for_status()
return response.json()
Expand Down Expand Up @@ -98,9 +110,9 @@ def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions:

@staticmethod
async def get_file_info(
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
) -> Union[bytes, tuple]:
"""
获取文件字节信息
Expand All @@ -120,7 +132,7 @@ async def get_file_info(
"""
if isinstance(filelike, (str, Path)):
filename = os.path.basename(str(filelike))
async with aiofiles.open(filelike, 'rb') as file:
async with aiofiles.open(filelike, "rb") as file:
file_bytes = await file.read()

if only_bytes:
Expand Down Expand Up @@ -154,7 +166,7 @@ async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info})
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data

Expand All @@ -173,25 +185,25 @@ async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self.__request_parse(endpoint=endpoint, files={'file': file_info})
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data

async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse a video file through the OmniParse media endpoint.

    Args:
        filelike: Path to the video file, or its raw bytes.
        bytes_filename: Filename to report when *filelike* is raw bytes.

    Returns:
        The server's JSON response as a dict.
    """
    self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
    endpoint = f"{self.parse_media_endpoint}/video"
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    return await self._request_parse(endpoint, files={"file": file_info})

async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse an audio file through the OmniParse media endpoint.

    Args:
        filelike: Path to the audio file, or its raw bytes.
        bytes_filename: Filename to report when *filelike* is raw bytes.

    Returns:
        The server's JSON response as a dict.
    """
    self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
    endpoint = f"{self.parse_media_endpoint}/audio"
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    return await self._request_parse(endpoint, files={"file": file_info})

async def parse_website(self, url: str) -> dict:
    """Parse a website by URL.

    FIXME: the official API still has known issues with this endpoint.

    Args:
        url: Address of the site to parse.

    Returns:
        The server's JSON response as a dict.
    """
    endpoint = f"{self.parse_website_endpoint}/parse"
    return await self._request_parse(endpoint, params={"url": url})
18 changes: 10 additions & 8 deletions metagpt/rag/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Literal, Optional, Union, List
from typing import Any, ClassVar, List, Literal, Optional, Union

# from chromadb.api.types import CollectionMetadata
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
Expand Down Expand Up @@ -68,9 +68,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):

persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
# metadata: Optional[CollectionMetadata] = Field(
# default=None, description="Optional metadata to associate with the collection"
# )
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)


class ElasticsearchStoreConfig(BaseModel):
Expand Down Expand Up @@ -166,9 +166,9 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""

collection_name: str = Field(default="metagpt", description="The name of the collection.")
# metadata: Optional[CollectionMetadata] = Field(
# default=None, description="Optional metadata to associate with the collection"
# )
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)


class BM25IndexConfig(BaseIndexConfig):
Expand Down Expand Up @@ -219,12 +219,14 @@ def get_obj_metadata(obj: RAGObject) -> dict:

class OmniParseType(str, Enum):
    """Parse type accepted by OmniParse (str-valued enum)."""

    PDF = "PDF"
    DOCUMENT = "DOCUMENT"


class OmniParseOptions(BaseModel):
    """Optional settings for OmniParse: result format, parse type, and request timeout."""

    # Result format returned by the parser (defaults to markdown).
    result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型")
    # Which parser to use; defaults to the generic document parser.
    parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型,默认文档类型")
    # Maximum request timeout in seconds for the OmniParse service.
    max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
Expand Down
11 changes: 10 additions & 1 deletion tests/metagpt/rag/engines/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@ def mock_get_rankers(self, mocker):
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")

@pytest.fixture
def mock_get_file_extractor(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor")

def test_from_docs(
self,
mocker,
mock_simple_directory_reader,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
mock_get_file_extractor,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
Expand All @@ -53,6 +58,8 @@ def test_from_docs(
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
file_extractor = mocker.MagicMock()
mock_get_file_extractor.return_value = file_extractor

# Setup
input_dir = "test_dir"
Expand All @@ -75,7 +82,9 @@ def test_from_docs(
)

# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
Expand Down
32 changes: 32 additions & 0 deletions tests/metagpt/rag/parser/test_omniparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parser.omniparse import OmniParseClient
from metagpt.rag.schema import OmniParsedResult


class TestOmniParseClient:
    """Unit tests for OmniParseClient with the HTTP request layer mocked out."""

    parse_client = OmniParseClient()

    # Sample files under the example-data directory.
    TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
    TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
    TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
    TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"

    @pytest.fixture
    def request_parse(self, mocker):
        # Patch the low-level request helper so no OmniParse server is required.
        return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")

    @pytest.mark.asyncio
    async def test_parse_pdf(self, request_parse):
        content = "#test title\ntest content"
        expected = OmniParsedResult(text=content, markdown=content)
        request_parse.return_value = expected.model_dump()

        result = await self.parse_client.parse_pdf(self.TEST_PDF)

        assert result == expected


class TestOmniParse:
    # Placeholder suite for the OmniParse reader/parser integration.
    def test_load_data(self):
        # TODO: no assertions yet — implement coverage for load_data.
        pass

0 comments on commit 79334de

Please sign in to comment.