
Reduce test time with a global LLM mock #685

Merged · 11 commits · Jan 5, 2024
2 changes: 2 additions & 0 deletions metagpt/actions/invoice_ocr.py
@@ -88,6 +88,8 @@ async def _unzip(file_path: Path) -> Path:
async def _ocr(invoice_file_path: Path):
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
for result in ocr_result[0]:
result[1] = (result[1][0], round(result[1][1], 2)) # round long confidence scores to reduce token costs
return ocr_result

async def run(self, file_path: Path, *args, **kwargs) -> list:
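The two added lines round PaddleOCR confidence scores to two decimals before the OCR result is embedded in an LLM prompt, since long floats serialize into many tokens. A minimal sketch of the effect, assuming PaddleOCR's usual [box, (text, confidence)] entry shape:

    # Assumed PaddleOCR result entry shape: [bounding_box, (text, confidence)].
    entry = [[[10.0, 10.0], [80.0, 10.0], [80.0, 30.0], [10.0, 30.0]], ("发票号码", 0.9876543210987654)]
    entry[1] = (entry[1][0], round(entry[1][1], 2))
    print(entry[1])  # ('发票号码', 0.99) -- far fewer prompt tokens than the 16-digit float
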
83 changes: 37 additions & 46 deletions tests/conftest.py
@@ -12,57 +12,20 @@
import os
import re
import uuid
from typing import Optional

import pytest

from metagpt.config import CONFIG, Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.git_repository import GitRepository
from tests.mock.mock_llm import MockLLM


class MockLLM(OpenAILLM):
rsp_cache: dict = {}

async def original_aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
):
"""A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
if system_msgs:
message = self._system_msgs(system_msgs)
else:
message = [self._default_system_msg()] if self.use_system_prompt else []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp

async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
) -> str:
if msg not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
logger.info(f"Added '{rsp[:20]}' ... to response cache")
self.rsp_cache[msg] = rsp
return rsp
else:
logger.info("Use response cache")
return self.rsp_cache[msg]
RSP_CACHE_NEW = {}  # global store that collects only the cache entries newly added during this test session
ALLOW_OPENAI_API_CALL = os.environ.get(
"ALLOW_OPENAI_API_CALL", True
) # NOTE: should change to default False once mock is complete
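Note that os.environ.get returns a string whenever the variable is set, so any non-empty value (even "false" or "0") is truthy here. A stricter parse could look like this sketch, where env_flag is a hypothetical helper rather than part of this PR:

    import os

    def env_flag(name: str, default: bool = False) -> bool:
        """Interpret a boolean-like environment variable; unset falls back to the default."""
        raw = os.environ.get(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    ALLOW_OPENAI_API_CALL = env_flag("ALLOW_OPENAI_API_CALL", default=True)
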


@pytest.fixture(scope="session")
@@ -76,16 +39,37 @@ def rsp_cache():
else:
rsp_cache_json = {}
yield rsp_cache_json
with open(new_rsp_cache_file_path, "w") as f2:
with open(rsp_cache_file_path, "w") as f2:
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
with open(new_rsp_cache_file_path, "w") as f2:
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
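
At session teardown the fixture now writes two files: the merged cache back to rsp_cache.json, and only this run's additions (RSP_CACHE_NEW) to a separate file for review. The cache is assumed to be a flat prompt-to-response JSON map, so the merge reduces to a dict update:

    import json

    # Assumed cache shape: {prompt_string: response_string}.
    merged_cache = {"existing prompt": "existing response"}
    new_entries = {"prompt recorded this run": "new response"}
    merged_cache.update(new_entries)

    # Illustrative paths; the fixture derives the real ones from TEST_DATA_PATH.
    with open("rsp_cache.json", "w") as f:
        json.dump(merged_cache, f, indent=4, ensure_ascii=False)
    with open("rsp_cache_new.json", "w") as f:
        json.dump(new_entries, f, indent=4, ensure_ascii=False)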


@pytest.fixture(scope="function")
def llm_mock(rsp_cache, mocker):
llm = MockLLM()
# Hook to capture the test result
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
outcome = yield
rep = outcome.get_result()
if rep.when == "call":
item.test_outcome = rep


@pytest.fixture(scope="function", autouse=True)
def llm_mock(rsp_cache, mocker, request):
llm = MockLLM(allow_open_api_call=ALLOW_OPENAI_API_CALL)
llm.rsp_cache = rsp_cache
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
yield mocker
if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
if llm.rsp_candidates:
for rsp_candidate in llm.rsp_candidates:
cand_key = list(rsp_candidate.keys())[0]
cand_value = list(rsp_candidate.values())[0]
if cand_key not in llm.rsp_cache:
logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
llm.rsp_cache.update(rsp_candidate)
RSP_CACHE_NEW.update(rsp_candidate)
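
The MockLLM class itself now lives in tests/mock/mock_llm.py and is not part of this diff. A minimal sketch of the contract the fixture relies on (cache-first lookup, per-call candidate recording, and an allow_open_api_call guard), assuming the class still derives from OpenAILLM like the removed in-file version:

    from metagpt.provider.openai_api import OpenAILLM

    class MockLLM(OpenAILLM):
        # Sketch only; the real implementation is in tests/mock/mock_llm.py.
        def __init__(self, allow_open_api_call: bool):
            super().__init__()
            self.allow_open_api_call = allow_open_api_call
            self.rsp_cache: dict = {}       # prompt -> cached response
            self.rsp_candidates: list = []  # one {prompt: response} dict per call

        async def aask(self, msg: str, *args, **kwargs) -> str:
            if msg in self.rsp_cache:
                rsp = self.rsp_cache[msg]  # cache hit: no API traffic
            elif self.allow_open_api_call:
                # original_aask is assumed to be an unmocked copy of
                # BaseLLM.aask, as in the removed class above.
                rsp = await self.original_aask(msg, *args, **kwargs)
            else:
                raise RuntimeError(f"No cached response for: {msg[:100]}")
            self.rsp_candidates.append({msg: rsp})  # persisted only if the test passes
            return rsp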


class Context:
@@ -170,6 +154,13 @@ def init_config():
Config()


@pytest.fixture(scope="function")
def new_filename(mocker):
# NOTE: Mock the new filename so llm aask prompts are reproducible; consider changing this after implementing requirement segmentation
mocker.patch("metagpt.utils.file_repository.FileRepository.new_filename", lambda: "20240101")
yield mocker
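
Note that the zero-argument lambda assumes FileRepository.new_filename resolves without an instance argument (for example, a static method); if it were an instance method, the patch would need a self parameter, as in lambda self: "20240101".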


@pytest.fixture
def aiohttp_mocker(mocker):
class MockAioResponse:
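Because the fixture is now autouse, tests no longer opt in with a usefixtures marker, which is why that marker is deleted across the test files below. A hypothetical test is mocked transparently:

    import pytest
    from metagpt.llm import LLM

    @pytest.mark.asyncio
    async def test_uses_cached_llm():  # hypothetical test, for illustration
        # No usefixtures("llm_mock") marker needed: the autouse fixture has
        # already patched BaseLLM.aask, so this call is served from
        # tests/data/rsp_cache.json whenever a cached response exists.
        rsp = await LLM().aask("Say hello")
        assert rsp
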
184 changes: 117 additions & 67 deletions tests/data/rsp_cache.json

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/metagpt/actions/test_debug_error.py
@@ -117,7 +117,6 @@ def test_player_calculate_score_with_multiple_aces(self):


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_debug_error():
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
ctx = RunCodeContext(
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_design_api.py
@@ -17,7 +17,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
for prd in inputs:
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_design_api_review.py
@@ -11,7 +11,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api_review():
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
api_design = """
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_generate_questions.py
@@ -20,7 +20,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_generate_questions():
action = GenerateQuestions()
rsp = await action.run(context)
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_invoice_ocr.py
@@ -54,7 +54,6 @@ async def test_generate_table(invoice_path: Path, expected_result: dict):
("invoice_path", "query", "expected_result"),
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
)
@pytest.mark.usefixtures("llm_mock")
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
invoice_path = TEST_DATA_PATH / invoice_path
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_prepare_interview.py
@@ -12,7 +12,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_prepare_interview():
action = PrepareInterview()
rsp = await action.run("I just graduated and hope to find a job as a Python engineer")
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_project_management.py
@@ -18,7 +18,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
9 changes: 1 addition & 8 deletions tests/metagpt/actions/test_research.py
@@ -8,14 +8,7 @@

import pytest

from metagpt.actions import CollectLinks, research


@pytest.mark.asyncio
async def test_action():
action = CollectLinks()
result = await action.run(topic="baidu")
assert result
from metagpt.actions import research


@pytest.mark.asyncio
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_summarize_code.py
@@ -177,7 +177,6 @@ def get_body(self):


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_summarize_code():
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_talk_action.py
@@ -33,7 +33,6 @@
),
],
)
@pytest.mark.usefixtures("llm_mock")
async def test_prompt(agent_description, language, context, knowledge, history_summary):
# Prerequisites
CONFIG.agent_description = agent_description
3 changes: 0 additions & 3 deletions tests/metagpt/actions/test_write_code.py
@@ -28,7 +28,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code():
context = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
@@ -45,7 +44,6 @@ async def test_write_code():


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_directly():
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
llm = LLM()
@@ -54,7 +52,6 @@ async def test_write_code_directly():


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_deps():
# Prerequisites
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_code_review.py
@@ -12,7 +12,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_review(capfd):
code = """
def add(a, b):
2 changes: 0 additions & 2 deletions tests/metagpt/actions/test_write_docstring.py
@@ -27,14 +27,12 @@ def greet(self):
],
ids=["google", "numpy", "sphinx"],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_docstring(style: str, part: str):
ret = await WriteDocstring().run(code, style=style)
assert part in ret


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write():
code = await WriteDocstring.write_docstring(__file__)
assert code
3 changes: 1 addition & 2 deletions tests/metagpt/actions/test_write_prd.py
@@ -18,8 +18,7 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_prd():
async def test_write_prd(new_filename):
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_prd_review.py
@@ -11,7 +11,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_prd_review():
prd = """
Introduction: This is a new feature for our product.
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_review.py
@@ -46,7 +46,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_review():
write_review = WriteReview()
review = await write_review.run(CONTEXT)
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_write_teaching_plan.py
@@ -16,7 +16,6 @@
("topic", "context"),
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_teaching_plan_part(topic, context):
action = WriteTeachingPlanPart(topic=topic, context=context)
rsp = await action.run()
2 changes: 0 additions & 2 deletions tests/metagpt/actions/test_write_test.py
@@ -13,7 +13,6 @@


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_test():
code = """
import random
@@ -40,7 +39,6 @@ def generate(self, max_y: int, max_x: int):


@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_invalid_code(mocker):
# Mock the _aask method to return an invalid code string
mocker.patch.object(WriteTest, "_aask", return_value="Invalid Code String")
2 changes: 0 additions & 2 deletions tests/metagpt/actions/test_write_tutorial.py
@@ -14,7 +14,6 @@

@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
@pytest.mark.usefixtures("llm_mock")
async def test_write_directory(language: str, topic: str):
ret = await WriteDirectory(language=language).run(topic=topic)
assert isinstance(ret, dict)
@@ -30,7 +29,6 @@ async def test_write_directory(language: str, topic: str):
("language", "topic", "directory"),
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_content(language: str, topic: str, directory: Dict):
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
assert isinstance(ret, str)
30 changes: 20 additions & 10 deletions tests/metagpt/document_store/test_qdrant_store.py
@@ -29,6 +29,16 @@
]


def assert_almost_equal(actual, expected):
delta = 1e-10
if isinstance(expected, list):
assert len(actual) == len(expected)
for ac, exp in zip(actual, expected):
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
else:
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
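
The helper mirrors what pytest's built-in approx provides for both scalars and lists; an equivalent check, should a built-in ever be preferred:

    import pytest

    # Equivalent absolute-tolerance comparisons using pytest.approx:
    assert 0.999106722578389 == pytest.approx(0.999106722578389, abs=1e-10)
    assert [0.7363563179969788, 0.6765939593315125] == pytest.approx(
        [0.7363563179969788, 0.6765939593315125], abs=1e-10
    )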


def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
@@ -42,30 +52,30 @@ def test_qdrant_store():
qdrant_store.add("Book", points)
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert_almost_equal(results[0]["score"], 0.999106722578389)
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert_almost_equal(results[1]["score"], 0.9961650411397226)
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert_almost_equal(results[0]["score"], 0.999106722578389)
assert_almost_equal(results[0]["vector"], [0.7363563179969788, 0.6765939593315125])
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
assert_almost_equal(results[1]["score"], 0.9961650411397226)
assert_almost_equal(results[1]["vector"], [0.7662628889083862, 0.6425272226333618])
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
)
assert results[0]["id"] == 8
assert results[0]["score"] == 0.9100373450784073
assert_almost_equal(results[0]["score"], 0.9100373450784073)
assert results[1]["id"] == 9
assert results[1]["score"] == 0.7127610621127889
assert_almost_equal(results[1]["score"], 0.7127610621127889)
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
return_vector=True,
)
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]
assert_almost_equal(results[0]["vector"], [0.35037919878959656, 0.9366079568862915])
assert_almost_equal(results[1]["vector"], [0.9999677538871765, 0.00802854634821415])
8 changes: 8 additions & 0 deletions tests/metagpt/provider/conftest.py
@@ -0,0 +1,8 @@
import pytest


@pytest.fixture(autouse=True)
def llm_mock(rsp_cache, mocker, request):
# An empty fixture to override the global llm_mock fixture,
# because under the provider folder we want to test the real aask and aask_batch methods of the specific models
pass
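
Fixture lookup in pytest is directory-scoped, so this no-op definition shadows the global autouse fixture for everything under tests/metagpt/provider/. A provider test can then reach the real methods; the sketch below is illustrative only, and the OpenAILLM constructor arguments are assumed:

    import pytest
    from metagpt.provider.openai_api import OpenAILLM

    @pytest.mark.asyncio
    async def test_openai_aask_unmocked():  # illustrative only
        # The no-op llm_mock above shadows the global autouse fixture,
        # so this call exercises the provider's real aask implementation.
        llm = OpenAILLM()
        rsp = await llm.aask("ping")
        assert isinstance(rsp, str)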