Skip to content

Commit

Permalink
Merge pull request #685 from garylin2099/llm_mock
Browse files Browse the repository at this point in the history
Reduce test time with a global LLM mock
  • Loading branch information
geekan authored Jan 5, 2024
2 parents 136b3f5 + bd4a35f commit 230192f
Show file tree
Hide file tree
Showing 51 changed files with 289 additions and 217 deletions.
2 changes: 2 additions & 0 deletions metagpt/actions/invoice_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ async def _unzip(file_path: Path) -> Path:
async def _ocr(invoice_file_path: Path):
    # Run PaddleOCR over the invoice file (Chinese model, first page only,
    # with angle classification enabled).
    ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
    ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
    # Each entry is [box, (text, confidence)]; round the confidence in place
    # so long float reprs don't inflate the prompt sent to the LLM.
    for result in ocr_result[0]:
        result[1] = (result[1][0], round(result[1][1], 2))  # round long confidence scores to reduce token costs
    return ocr_result

async def run(self, file_path: Path, *args, **kwargs) -> list:
Expand Down
83 changes: 37 additions & 46 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,20 @@
import os
import re
import uuid
from typing import Optional

import pytest

from metagpt.config import CONFIG, Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.git_repository import GitRepository
from tests.mock.mock_llm import MockLLM


class MockLLM(OpenAILLM):
rsp_cache: dict = {}

async def original_aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
):
"""A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
if system_msgs:
message = self._system_msgs(system_msgs)
else:
message = [self._default_system_msg()] if self.use_system_prompt else []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp

async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
) -> str:
if msg not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
logger.info(f"Added '{rsp[:20]}' ... to response cache")
self.rsp_cache[msg] = rsp
return rsp
else:
logger.info("Use response cache")
return self.rsp_cache[msg]
RSP_CACHE_NEW = {} # used globally for producing new and useful only response cache
ALLOW_OPENAI_API_CALL = os.environ.get(
"ALLOW_OPENAI_API_CALL", True
) # NOTE: should change to default False once mock is complete


@pytest.fixture(scope="session")
Expand All @@ -76,16 +39,37 @@ def rsp_cache():
else:
rsp_cache_json = {}
yield rsp_cache_json
with open(new_rsp_cache_file_path, "w") as f2:
with open(rsp_cache_file_path, "w") as f2:
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
with open(new_rsp_cache_file_path, "w") as f2:
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)


@pytest.fixture(scope="function")
def llm_mock(rsp_cache, mocker):
llm = MockLLM()
# Hook to capture the test result
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Hookwrapper that stashes each test's call-phase report on its item.

    The ``llm_mock`` fixture later reads ``item.test_outcome`` to decide
    whether the test passed before persisting new LLM responses.
    """
    result = yield
    report = result.get_result()
    if report.when == "call":
        item.test_outcome = report


@pytest.fixture(scope="function", autouse=True)
def llm_mock(rsp_cache, mocker, request):
    """Globally substitute BaseLLM.aask / aask_batch with a caching MockLLM.

    After the test finishes, if it passed, merge every (prompt -> response)
    candidate the mock recorded into the session cache and into
    ``RSP_CACHE_NEW``, logging only entries not already cached.
    """
    mock_llm = MockLLM(allow_open_api_call=ALLOW_OPENAI_API_CALL)
    mock_llm.rsp_cache = rsp_cache
    mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm.aask)
    mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", mock_llm.aask_batch)
    yield mocker
    outcome = getattr(request.node, "test_outcome", None)
    if outcome is None or not outcome.passed:
        return
    for candidate in mock_llm.rsp_candidates:
        prompt, response = list(candidate.items())[0]
        if prompt not in mock_llm.rsp_cache:
            logger.info(f"Added '{prompt[:100]} ... -> {response[:20]} ...' to response cache")
        mock_llm.rsp_cache.update(candidate)
        RSP_CACHE_NEW.update(candidate)


class Context:
Expand Down Expand Up @@ -173,6 +157,13 @@ def init_config():
Config()


@pytest.fixture(scope="function")
def new_filename(mocker):
    """Pin FileRepository.new_filename to a constant so LLM prompts are reproducible.

    NOTE: should be reconsidered once requirement segmentation is implemented.
    """

    def _fixed_filename():
        return "20240101"

    mocker.patch("metagpt.utils.file_repository.FileRepository.new_filename", _fixed_filename)
    yield mocker


@pytest.fixture
def aiohttp_mocker(mocker):
class MockAioResponse:
Expand Down
184 changes: 117 additions & 67 deletions tests/data/rsp_cache.json

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/metagpt/actions/test_debug_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def test_player_calculate_score_with_multiple_aces(self):


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_debug_error():
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
ctx = RunCodeContext(
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_design_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
for prd in inputs:
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_design_api_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api_review():
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
api_design = """
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_generate_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_generate_questions():
action = GenerateQuestions()
rsp = await action.run(context)
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_invoice_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ async def test_generate_table(invoice_path: Path, expected_result: dict):
("invoice_path", "query", "expected_result"),
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
)
@pytest.mark.usefixtures("llm_mock")
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
invoice_path = TEST_DATA_PATH / invoice_path
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_prepare_interview.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_prepare_interview():
action = PrepareInterview()
rsp = await action.run("I just graduated and hope to find a job as a Python engineer")
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_project_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
Expand Down
9 changes: 1 addition & 8 deletions tests/metagpt/actions/test_research.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@

import pytest

from metagpt.actions import CollectLinks, research


@pytest.mark.asyncio
async def test_action():
action = CollectLinks()
result = await action.run(topic="baidu")
assert result
from metagpt.actions import research


@pytest.mark.asyncio
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_summarize_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def get_body(self):


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_summarize_code():
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_talk_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
),
],
)
@pytest.mark.usefixtures("llm_mock")
async def test_prompt(agent_description, language, context, knowledge, history_summary):
# Prerequisites
CONFIG.agent_description = agent_description
Expand Down
3 changes: 0 additions & 3 deletions tests/metagpt/actions/test_write_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code():
context = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
Expand All @@ -45,7 +44,6 @@ async def test_write_code():


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_directly():
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
llm = LLM()
Expand All @@ -54,7 +52,6 @@ async def test_write_code_directly():


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_deps():
# Prerequisites
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_code_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_review(capfd):
code = """
def add(a, b):
Expand Down
2 changes: 0 additions & 2 deletions tests/metagpt/actions/test_write_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@ def greet(self):
],
ids=["google", "numpy", "sphinx"],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_docstring(style: str, part: str):
ret = await WriteDocstring().run(code, style=style)
assert part in ret


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write():
code = await WriteDocstring.write_docstring(__file__)
assert code
Expand Down
3 changes: 1 addition & 2 deletions tests/metagpt/actions/test_write_prd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_prd():
async def test_write_prd(new_filename):
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_prd_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_prd_review():
prd = """
Introduction: This is a new feature for our product.
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_review():
write_review = WriteReview()
review = await write_review.run(CONTEXT)
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_teaching_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
("topic", "context"),
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_teaching_plan_part(topic, context):
action = WriteTeachingPlanPart(topic=topic, context=context)
rsp = await action.run()
Expand Down
2 changes: 0 additions & 2 deletions tests/metagpt/actions/test_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_test():
code = """
import random
Expand All @@ -40,7 +39,6 @@ def generate(self, max_y: int, max_x: int):


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_invalid_code(mocker):
# Mock the _aask method to return an invalid code string
mocker.patch.object(WriteTest, "_aask", return_value="Invalid Code String")
Expand Down
2 changes: 0 additions & 2 deletions tests/metagpt/actions/test_write_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
@pytest.mark.usefixtures("llm_mock")
async def test_write_directory(language: str, topic: str):
ret = await WriteDirectory(language=language).run(topic=topic)
assert isinstance(ret, dict)
Expand All @@ -30,7 +29,6 @@ async def test_write_directory(language: str, topic: str):
("language", "topic", "directory"),
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_content(language: str, topic: str, directory: Dict):
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
assert isinstance(ret, str)
Expand Down
30 changes: 20 additions & 10 deletions tests/metagpt/document_store/test_qdrant_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@
]


def assert_almost_equal(actual, expected, delta: float = 1e-10):
    """Assert that two numbers (or sequences of numbers) are equal within ``delta``.

    Generalized over the original: recurses elementwise, so nested lists
    (e.g. a list of vectors) are supported, and tuples are accepted alongside
    lists. ``delta`` keeps the historical default of 1e-10 but can now be
    overridden by callers.

    Raises:
        AssertionError: if lengths differ or any value is outside the tolerance.
    """
    if isinstance(expected, (list, tuple)):
        assert len(actual) == len(expected), f"length mismatch: {len(actual)} != {len(expected)}"
        for ac, exp in zip(actual, expected):
            assert_almost_equal(ac, exp, delta)
    else:
        assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"


def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
Expand All @@ -42,30 +52,30 @@ def test_qdrant_store():
qdrant_store.add("Book", points)
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert_almost_equal(results[0]["score"], 0.999106722578389)
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert_almost_equal(results[1]["score"], 0.9961650411397226)
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert_almost_equal(results[0]["score"], 0.999106722578389)
assert_almost_equal(results[0]["vector"], [0.7363563179969788, 0.6765939593315125])
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
assert_almost_equal(results[1]["score"], 0.9961650411397226)
assert_almost_equal(results[1]["vector"], [0.7662628889083862, 0.6425272226333618])
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
)
assert results[0]["id"] == 8
assert results[0]["score"] == 0.9100373450784073
assert_almost_equal(results[0]["score"], 0.9100373450784073)
assert results[1]["id"] == 9
assert results[1]["score"] == 0.7127610621127889
assert_almost_equal(results[1]["score"], 0.7127610621127889)
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
return_vector=True,
)
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]
assert_almost_equal(results[0]["vector"], [0.35037919878959656, 0.9366079568862915])
assert_almost_equal(results[1]["vector"], [0.9999677538871765, 0.00802854634821415])
8 changes: 8 additions & 0 deletions tests/metagpt/provider/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest


@pytest.fixture(autouse=True)
def llm_mock(rsp_cache, mocker, request):
    # Deliberately empty fixture that shadows the global autouse `llm_mock`
    # fixture from tests/conftest.py: inside the provider folder we want the
    # real aask/aask_batch implementations of each specific model exercised,
    # so BaseLLM must not be patched here.
    pass
Loading

0 comments on commit 230192f

Please sign in to comment.