Merge pull request #615 from iorisa/fixbug/geekan/dev
fixbug: timeout & prompt_format
geekan authored Dec 24, 2023
2 parents a1f39d1 + e6a5e8e commit 8d1a3ce
Showing 14 changed files with 41 additions and 36 deletions.
1 change: 1 addition & 0 deletions config/config.yaml
@@ -15,6 +15,7 @@ OPENAI_API_MODEL: "gpt-4-1106-preview"
MAX_TOKENS: 4096
RPM: 10
LLM_TYPE: OpenAI # Except for these three major models – OpenAI, MetaGPT LLM, and Azure – other large models can be distinguished based on the validity of the key.
+TIMEOUT: 60 # Timeout for llm invocation

#### if Spark
#SPARK_APPID : "YOUR_APPID"
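The new key is read back in `metagpt/config.py` (see that file's diff below), where a missing `TIMEOUT` falls back to 3 seconds. A minimal sketch of that lookup; `raw_config` is a hypothetical stand-in for the parsed YAML:

```python
# Hypothetical stand-in for the parsed config.yaml contents.
raw_config = {"TIMEOUT": 60}

# Mirrors `self.timeout = int(self._get("TIMEOUT", 3))` in metagpt/config.py:
# the configured value when present, otherwise a 3-second default.
timeout = int(raw_config.get("TIMEOUT", 3))
assert timeout == 60
assert int({}.get("TIMEOUT", 3)) == 3
```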
15 changes: 9 additions & 6 deletions metagpt/actions/action_node.py
@@ -14,6 +14,7 @@
from pydantic import BaseModel, create_model, root_validator, validator
from tenacity import retry, stop_after_attempt, wait_random_exponential

+from metagpt.config import CONFIG
from metagpt.llm import BaseGPTAPI
from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
@@ -260,9 +261,10 @@ async def _aask_v1(
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
+timeout=CONFIG.timeout,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
-content = await self.llm.aask(prompt, system_msgs)
+content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
logger.debug(f"llm raw output:\n{content}")
output_class = self.create_model_class(output_class_name, output_data_mapping)

@@ -289,13 +291,13 @@ def set_llm(self, llm):
def set_context(self, context):
self.set_recursive("context", context)

-async def simple_fill(self, schema, mode):
+async def simple_fill(self, schema, mode, timeout=CONFIG.timeout):
prompt = self.compile(context=self.context, schema=schema, mode=mode)

if schema != "raw":
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
-content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema)
+content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
self.content = content
self.instruct_content = scontent
else:
@@ -304,7 +306,7 @@ async def simple_fill(self, schema, mode):

return self

-async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
+async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
@@ -320,6 +322,7 @@ async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
:param strgy: simple/complex
- simple: run only once
- complex: run each node
+:param timeout: Timeout for llm invocation.
:return: self
"""
self.set_llm(llm)
@@ -328,12 +331,12 @@ async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
schema = self.schema

if strgy == "simple":
-return await self.simple_fill(schema=schema, mode=mode)
+return await self.simple_fill(schema=schema, mode=mode, timeout=timeout)
elif strgy == "complex":
# This implicitly assumes the node has children
tmp = {}
for _, i in self.children.items():
-child = await i.simple_fill(schema=schema, mode=mode)
+child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
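Taken together, these hunks thread one `timeout` value from `fill()` through `simple_fill()` and `_aask_v1()` down to `llm.aask()`, instead of silently dropping it at the first hop. A runnable sketch of the same plumbing, using a stub LLM in place of the real `BaseGPTAPI`:

```python
import asyncio

DEFAULT_TIMEOUT = 3  # stands in for CONFIG.timeout


class StubLLM:
    """Hypothetical stand-in for the real provider; it just echoes the timeout."""

    async def aask(self, prompt: str, system_msgs=None, timeout=DEFAULT_TIMEOUT) -> str:
        return f"answered {prompt!r} with timeout={timeout}"


class NodeSketch:
    def __init__(self, llm):
        self.llm = llm

    async def _aask_v1(self, prompt: str, timeout=DEFAULT_TIMEOUT) -> str:
        # The fix: forward `timeout` instead of dropping it here.
        return await self.llm.aask(prompt, timeout=timeout)

    async def simple_fill(self, prompt: str, timeout=DEFAULT_TIMEOUT) -> str:
        return await self._aask_v1(prompt, timeout=timeout)

    async def fill(self, prompt: str, timeout=DEFAULT_TIMEOUT) -> str:
        return await self.simple_fill(prompt, timeout=timeout)


print(asyncio.run(NodeSketch(StubLLM()).fill("design the API", timeout=60)))
# -> answered 'design the API' with timeout=60
```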
6 changes: 3 additions & 3 deletions metagpt/actions/design_api.py
@@ -51,7 +51,7 @@ class WriteDesign(Action):
"clearly and in detail."
)

-async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
+async def run(self, with_messages: Message, schema: str = CONFIG.prompt_format):
# Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory.
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
changed_prds = prds_file_repo.changed_files
@@ -81,11 +81,11 @@ async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
# leaving room for global optimization in subsequent steps.
return ActionOutput(content=changed_files.json(), instruct_content=changed_files)

-async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
+async def _new_system_design(self, context, schema=CONFIG.prompt_format):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
return node

-async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
+async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_format):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
system_design_doc.content = node.instruct_content.json(ensure_ascii=False)
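This is the first of several identical `CONFIG.prompt_schema` → `CONFIG.prompt_format` renames; the same fix recurs in `project_management.py`, `write_prd.py`, and `utils/get_template.py` below. The bug bites early because Python evaluates default arguments when the function is defined, not when it is called. A hedged sketch of the failure mode, with a placeholder `Config` class (the real `CONFIG` presumably exposes `prompt_format`, not `prompt_schema`):

```python
class Config:
    """Placeholder: only the renamed attribute exists."""
    prompt_format = "json"


CONFIG = Config()

# Before the fix, definitions like
#     async def run(self, ..., schema=CONFIG.prompt_schema): ...
# could raise AttributeError (or silently bind a wrong default) as soon as
# the module was imported, since the default expression runs at def time.
def run(schema=CONFIG.prompt_format):
    return schema


assert run() == "json"
```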
6 changes: 3 additions & 3 deletions metagpt/actions/project_management.py
@@ -45,7 +45,7 @@ class WriteTasks(Action):
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)

-async def run(self, with_messages, schema=CONFIG.prompt_schema):
+async def run(self, with_messages, schema=CONFIG.prompt_format):
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files

@@ -92,14 +92,14 @@ async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo
await self._save_pdf(task_doc=task_doc)
return task_doc

-async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
+async def _run_new_tasks(self, context, schema=CONFIG.prompt_format):
node = await PM_NODE.fill(context, self.llm, schema)
# prompt_template, format_example = get_template(templates, format)
# prompt = prompt_template.format(context=context, format_example=format_example)
# rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format)
return node

-async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
+async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_format) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
node = await PM_NODE.fill(context, self.llm, schema)
task_doc.content = node.instruct_content.json(ensure_ascii=False)
1 change: 0 additions & 1 deletion metagpt/actions/research.py
@@ -181,7 +181,6 @@ class WebBrowseAndSummarize(Action):
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = WebBrowserEngine(
-options={}, # FIXME: REMOVE options?
engine=WebBrowserEngineType.CUSTOM if browse_func else None,
run_func=browse_func,
)
6 changes: 3 additions & 3 deletions metagpt/actions/write_prd.py
@@ -69,7 +69,7 @@ class WritePRD(Action):
content: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)

-async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
+async def run(self, with_messages, schema=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
@@ -113,7 +113,7 @@ async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs)
# optimization in subsequent steps.
return ActionOutput(content=change_files.json(), instruct_content=change_files)

-async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
+async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_format) -> ActionOutput:
# sas = SearchAndSummarize()
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
# rsp = ""
@@ -132,7 +132,7 @@ async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
return node.get("is_relative") == "YES"

-async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
+async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_format) -> Document:
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
10 changes: 8 additions & 2 deletions metagpt/config.py
@@ -109,8 +109,13 @@ def get_default_llm_provider_enum(self) -> LLMProviderEnum:

if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
-if self.openai_api_key and self.openai_api_model:
-    logger.info(f"OpenAI API Model: {self.openai_api_model}")
+model_mappings = {
+    LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
+    LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
+}
+model_name = model_mappings.get(provider)
+if model_name:
+    logger.info(f"{provider} Model: {model_name}")
if provider:
logger.info(f"API: {provider}")
return provider
@@ -187,6 +192,7 @@ def _update(self):
self.workspace_path = self.workspace_path / workspace_uid
self._ensure_workspace_exists()
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
+self.timeout = int(self._get("TIMEOUT", 3))

def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""
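The first hunk replaces an OpenAI-only log line with a per-provider lookup table, so Azure deployments get the same model-name logging. A self-contained sketch of the dispatch, with placeholder values for the two config fields:

```python
from enum import Enum


class LLMProviderEnum(Enum):  # trimmed to the two providers in the mapping
    OPENAI = "OpenAI"
    AZURE_OPENAI = "Azure OpenAI"


OPENAI_API_MODEL = "gpt-4-1106-preview"
DEPLOYMENT_NAME = "my-azure-deployment"  # hypothetical Azure deployment name

model_mappings = {
    LLMProviderEnum.OPENAI: OPENAI_API_MODEL,
    LLMProviderEnum.AZURE_OPENAI: DEPLOYMENT_NAME,
}

provider = LLMProviderEnum.AZURE_OPENAI
model_name = model_mappings.get(provider)  # None for any unmapped provider
if model_name:
    print(f"{provider} Model: {model_name}")  # log line skipped when unmapped
```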
6 changes: 1 addition & 5 deletions metagpt/provider/azure_openai_api.py
@@ -64,10 +64,6 @@ def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
}
if configs:
kwargs.update(configs)
-try:
-    default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0
-except ValueError:
-    default_timeout = 0
-kwargs["timeout"] = max(default_timeout, timeout)
+kwargs["timeout"] = max(CONFIG.timeout, timeout)

return kwargs
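With `config.py` now coercing `TIMEOUT` to `int` once at load time (see the `_update` hunk above), the provider-local `try/except ValueError` becomes dead code; the effective timeout is simply the larger of the configured default and the per-call argument. A one-line sketch of the surviving logic:

```python
config_timeout = 60   # CONFIG.timeout, already an int after config.py's coercion
per_call_timeout = 3  # the _cons_kwargs default

# The configured value wins unless a caller explicitly asks for more.
assert max(config_timeout, per_call_timeout) == 60
assert max(config_timeout, 120) == 120
```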
8 changes: 2 additions & 6 deletions metagpt/provider/openai_api.py
@@ -129,7 +129,7 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str
)

async for chunk in response:
-chunk_message = chunk.choices[0].delta.content or "" # extract the message
+chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message

def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
@@ -143,11 +143,7 @@ def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
}
if configs:
kwargs.update(configs)
-try:
-    default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0
-except ValueError:
-    default_timeout = 0
-kwargs["timeout"] = max(default_timeout, timeout)
+kwargs["timeout"] = max(CONFIG.timeout, timeout)

return kwargs

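The streaming hunk guards the `choices[0]` index, since streamed responses can include chunks whose `choices` list is empty (usage-only or content-filter chunks, for example); previously such a chunk raised `IndexError` mid-stream. A minimal sketch of the fixed expression with stub chunk objects:

```python
from types import SimpleNamespace


def extract(chunk) -> str:
    # Mirrors the fixed line: fall back to "" when `choices` is empty.
    return (chunk.choices[0].delta.content or "") if chunk.choices else ""


full = SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"))])
empty = SimpleNamespace(choices=[])  # e.g. a usage-only chunk

assert extract(full) == "hi"
assert extract(empty) == ""  # previously raised IndexError
```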
2 changes: 1 addition & 1 deletion metagpt/roles/engineer.py
@@ -311,4 +311,4 @@ async def _new_summarize_actions(self):
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
-return self._next_todo
+return self.next_todo_action
10 changes: 7 additions & 3 deletions metagpt/roles/product_manager.py
@@ -7,7 +7,6 @@
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
"""


from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
@@ -39,14 +38,19 @@ def __init__(self, **kwargs) -> None:
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)

-async def _think(self) -> None:
+async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo:
self._set_state(1)
else:
self._set_state(0)
self.todo_action = any_to_name(WritePRD)
-return self._rc.todo
+return bool(self._rc.todo)

+async def _observe(self, ignore_memory=False) -> int:
+    return await super()._observe(ignore_memory=True)
+
+@property
+def todo(self) -> str:
+    """AgentStore uses this attribute to display to the user what actions the current role should take."""
+    return self.todo_action
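Three behavioral tweaks land here: `_think` now returns a truth value matching its new `bool` annotation, `_observe` pins `ignore_memory=True` so watched messages are reprocessed rather than filtered against memory, and a `todo` property mirrors the one fixed in `Engineer` above. A hedged reduction of the `_think` decision, with stub plumbing standing in for the real base-class machinery:

```python
import asyncio
from types import SimpleNamespace


class ProductManagerSketch:
    """Hypothetical reduction; the real `_set_state` resolves an index into
    the role's action list and stores that action on `self._rc.todo`."""

    ACTIONS = ["PrepareDocuments", "WritePRD"]

    def __init__(self, git_repo=None):
        self.git_repo = git_repo  # stands in for CONFIG.git_repo
        self._rc = SimpleNamespace(todo=None)

    def _set_state(self, state: int) -> None:
        self._rc.todo = self.ACTIONS[state]

    async def _think(self) -> bool:
        # Once a git repo exists, move on to WritePRD; otherwise prepare docs.
        self._set_state(1 if self.git_repo else 0)
        return bool(self._rc.todo)  # a bool now, not the todo object itself


assert asyncio.run(ProductManagerSketch(git_repo="ws/repo")._think()) is True
```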
3 changes: 1 addition & 2 deletions metagpt/tools/web_browser_engine.py
@@ -6,7 +6,7 @@
from __future__ import annotations

import importlib
-from typing import Any, Callable, Coroutine, Dict, Literal, overload
+from typing import Any, Callable, Coroutine, Literal, overload

from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
@@ -16,7 +16,6 @@
class WebBrowserEngine:
def __init__(
self,
-options: Dict,
engine: WebBrowserEngineType | None = None,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
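With the unused `options` mapping gone from the constructor (its lone call site was dropped in `research.py` above), `WebBrowserEngine` is built from just an engine type and an optional custom run coroutine. A sketch of the slimmed-down signature; the class body here is a hypothetical reduction:

```python
from __future__ import annotations

from typing import Any, Callable, Coroutine


class WebBrowserEngineSketch:
    """Hypothetical reduction of the post-change constructor."""

    def __init__(
        self,
        engine=None,  # WebBrowserEngineType | None in the real class
        run_func: Callable[..., Coroutine[Any, Any, Any]] | None = None,
    ):
        self.engine = engine
        self.run_func = run_func


# Callers no longer pass the removed `options` mapping:
browser = WebBrowserEngineSketch(engine=None, run_func=None)
```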
2 changes: 1 addition & 1 deletion metagpt/utils/get_template.py
@@ -8,7 +8,7 @@
from metagpt.config import CONFIG


-def get_template(templates, schema=CONFIG.prompt_schema):
+def get_template(templates, schema=CONFIG.prompt_format):
selected_templates = templates.get(schema)
if selected_templates is None:
raise ValueError(f"Can't find {schema} in passed in templates")
1 change: 1 addition & 0 deletions requirements.txt
@@ -60,3 +60,4 @@ websockets~=12.0
networkx~=3.2.1
pylint~=3.0.3
google-generativeai==0.3.1
+playwright==1.40.0
