Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Replace CONTEXT with Context() #784

Merged
merged 8 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metagpt/actions/write_code_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def run(self, *args, **kwargs) -> CodingContext:
format_example=format_example,
)
len1 = len(iterative_code) if iterative_code else 0
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
logger.info(
f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
f"len(self.i_context.code_doc.content)={len2}"
Expand Down
1 change: 1 addition & 0 deletions metagpt/config2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def check_project_path(self):
if self.project_path:
self.inc = True
self.project_name = self.project_name or Path(self.project_path).name
return self


class Config(CLIParams, YamlModel):
Expand Down
4 changes: 0 additions & 4 deletions metagpt/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,3 @@ def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLL
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
return llm


# Global context, not in Env
CONTEXT = Context()
6 changes: 3 additions & 3 deletions metagpt/context_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel, ConfigDict, Field

from metagpt.config2 import Config
from metagpt.context import CONTEXT, Context
from metagpt.context import Context
from metagpt.provider.base_llm import BaseLLM


Expand All @@ -34,7 +34,7 @@ class ContextMixin(BaseModel):

def __init__(
self,
context: Optional[Context] = CONTEXT,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
**kwargs,
Expand Down Expand Up @@ -81,7 +81,7 @@ def context(self) -> Context:
"""Role context: role context > context"""
if self.private_context:
return self.private_context
return CONTEXT
return Context()

@context.setter
def context(self, context: Context) -> None:
Expand Down
7 changes: 4 additions & 3 deletions metagpt/learn/skill_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import yaml
from pydantic import BaseModel, Field

from metagpt.context import CONTEXT, Context
from metagpt.context import Context


class Example(BaseModel):
Expand Down Expand Up @@ -73,14 +73,15 @@ async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)

def get_skill_list(self, entity_name: str = "Assistant", context: Context = CONTEXT) -> Dict:
def get_skill_list(self, entity_name: str = "Assistant", context: Context = None) -> Dict:
"""Return the skill name based on the skill description."""
entity = self.entities.get(entity_name)
if not entity:
return {}

# List of skills that the agent chooses to activate.
agent_skills = context.kwargs.agent_skills
ctx = context or Context()
agent_skills = ctx.kwargs.agent_skills
if not agent_skills:
return {}

Expand Down
9 changes: 5 additions & 4 deletions metagpt/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from typing import Optional

from metagpt.configs.llm_config import LLMConfig
from metagpt.context import CONTEXT
from metagpt.context import Context
from metagpt.provider.base_llm import BaseLLM


def LLM(llm_config: Optional[LLMConfig] = None) -> BaseLLM:
def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> BaseLLM:
"""get the default llm provider if name is None"""
ctx = context or Context()
if llm_config is not None:
CONTEXT.llm_with_cost_manager_from_llm_config(llm_config)
return CONTEXT.llm()
ctx.llm_with_cost_manager_from_llm_config(llm_config)
return ctx.llm()
3 changes: 1 addition & 2 deletions metagpt/roles/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
from metagpt.actions.talk_action import TalkAction
from metagpt.context import CONTEXT
from metagpt.learn.skill_loader import SkillsDeclaration
from metagpt.logs import logger
from metagpt.memory.brain_memory import BrainMemory
Expand All @@ -48,7 +47,7 @@ class Assistant(Role):

def __init__(self, **kwargs):
super().__init__(**kwargs)
language = kwargs.get("language") or self.context.kwargs.language or CONTEXT.kwargs.language
language = kwargs.get("language") or self.context.kwargs.language
self.constraints = self.constraints.format(language=language)

async def think(self) -> bool:
Expand Down
6 changes: 4 additions & 2 deletions metagpt/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from metagpt.config2 import config
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.context import Context

app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)

Expand Down Expand Up @@ -37,9 +38,10 @@ def generate_repo(
from metagpt.team import Team

config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
ctx = Context(config=config)

if not recover_path:
company = Team()
company = Team(context=ctx)
company.hire(
[
ProductManager(),
Expand All @@ -58,7 +60,7 @@ def generate_repo(
if not stg_path.exists() or not str(stg_path).endswith("team"):
raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`")

company = Team.deserialize(stg_path=stg_path)
company = Team.deserialize(stg_path=stg_path, context=ctx)
idea = company.idea

company.invest(investment)
Expand Down
17 changes: 12 additions & 5 deletions metagpt/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

import warnings
from pathlib import Path
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field

from metagpt.actions import UserRequirement
from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH
from metagpt.context import Context
from metagpt.environment import Environment
from metagpt.logs import logger
from metagpt.roles import Role
Expand All @@ -36,12 +37,17 @@ class Team(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

env: Environment = Field(default_factory=Environment)
env: Optional[Environment] = None
investment: float = Field(default=10.0)
idea: str = Field(default="")

def __init__(self, **data: Any):
def __init__(self, context: Context = None, **data: Any):
super(Team, self).__init__(**data)
ctx = context or Context()
if not self.env:
self.env = Environment(context=ctx)
else:
self.env.context = ctx # The `env` object is allocated by deserialization
if "roles" in data:
self.hire(data["roles"])
if "env_desc" in data:
Expand All @@ -54,7 +60,7 @@ def serialize(self, stg_path: Path = None):
write_json_file(team_info_path, self.model_dump())

@classmethod
def deserialize(cls, stg_path: Path) -> "Team":
def deserialize(cls, stg_path: Path, context: Context = None) -> "Team":
"""stg_path = ./storage/team"""
# recover team_info
team_info_path = stg_path.joinpath("team.json")
Expand All @@ -64,7 +70,8 @@ def deserialize(cls, stg_path: Path) -> "Team":
)

team_info: dict = read_json_file(team_info_path)
team = Team(**team_info)
ctx = context or Context()
team = Team(**team_info, context=ctx)
return team

def hire(self, roles: list[Role]):
Expand Down
10 changes: 8 additions & 2 deletions metagpt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import aiofiles
import loguru
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils
from tenacity import RetryCallState, RetryError, _utils

from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
Expand Down Expand Up @@ -505,7 +505,7 @@ async def wrapper(self, *args, **kwargs):
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
except Exception:
except Exception as e:
if self.latest_observed_msg:
logger.warning(
"There is a exception in role's execution, in order to resume, "
Expand All @@ -514,6 +514,12 @@ async def wrapper(self, *args, **kwargs):
# remove role newest observed msg to make it observed again
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
if isinstance(e, RetryError):
last_error = e.last_attempt._exception
name = any_to_str(last_error)
if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name):
raise last_error

raise Exception(format_trackback_info(limit=None))

return wrapper
Expand Down
Binary file added tests/data/audio/hello.mp3
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/metagpt/roles/test_engineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_parse_code():

def test_todo():
role = Engineer()
assert role.todo == any_to_name(WriteCode)
assert role.action_description == any_to_name(WriteCode)


@pytest.mark.asyncio
Expand Down
4 changes: 4 additions & 0 deletions tests/metagpt/serialize_deserialize/test_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,7 @@ async def test_team_recover_multi_roles_save(mocker, context):
assert new_company.env.get_role(role_b.profile).rc.state == 1

await new_company.run(n_round=4)


if __name__ == "__main__":
pytest.main([__file__, "-s"])
7 changes: 4 additions & 3 deletions tests/metagpt/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@File : test_context.py
"""
from metagpt.configs.llm_config import LLMType
from metagpt.context import CONTEXT, AttrDict, Context
from metagpt.context import AttrDict, Context


def test_attr_dict_1():
Expand Down Expand Up @@ -51,11 +51,12 @@ def test_context_1():


def test_context_2():
llm = CONTEXT.config.get_openai_llm()
ctx = Context()
llm = ctx.config.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI

kwargs = CONTEXT.kwargs
kwargs = ctx.kwargs
assert kwargs is not None

kwargs.test_key = "test_value"
Expand Down
7 changes: 3 additions & 4 deletions tests/metagpt/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pytest

from metagpt.actions import UserRequirement
from metagpt.context import CONTEXT
from metagpt.environment import Environment
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, Role
Expand Down Expand Up @@ -44,9 +43,9 @@ def test_get_roles(env: Environment):

@pytest.mark.asyncio
async def test_publish_and_process_message(env: Environment):
if CONTEXT.git_repo:
CONTEXT.git_repo.delete_repository()
CONTEXT.git_repo = None
if env.context.git_repo:
env.context.git_repo.delete_repository()
env.context.git_repo = None

product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限")
architect = Architect(
Expand Down
2 changes: 1 addition & 1 deletion tests/metagpt/test_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def test_recover():
role.recovered = True
role.latest_observed_msg = Message(content="recover_test")
role.rc.state = 0
assert role.first_action == any_to_name(MockAction)
assert role.action_description == any_to_name(MockAction)

rsp = await role.run()
assert rsp.cause_by == any_to_str(MockAction)
Expand Down
Loading