diff --git a/.gitignore b/.gitignore
index 6dd3608f1..c10191dcc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,6 +172,8 @@ tests/metagpt/utils/file_repo_git
*.png
htmlcov
htmlcov.*
+*.dot
*.pkl
*-structure.csv
*-structure.json
+
diff --git a/docs/.well-known/openapi.yaml b/docs/.well-known/openapi.yaml
index bc291b7db..47ca04b23 100644
--- a/docs/.well-known/openapi.yaml
+++ b/docs/.well-known/openapi.yaml
@@ -11,7 +11,7 @@ paths:
post:
summary: Generate greeting
description: Generates a greeting message.
- operationId: hello.post_greeting
+ operationId: openapi_v3_hello.post_greeting
responses:
200:
description: greeting response
diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py
index 66bc2c7ab..dbc11d14b 100644
--- a/metagpt/actions/rebuild_class_view.py
+++ b/metagpt/actions/rebuild_class_view.py
@@ -9,60 +9,187 @@
import re
from pathlib import Path
+import aiofiles
+
from metagpt.actions import Action
from metagpt.config import CONFIG
-from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
+from metagpt.const import (
+ AGGREGATION,
+ COMPOSITION,
+ DATA_API_DESIGN_FILE_REPO,
+ GENERALIZATION,
+ GRAPH_REPO_FILE_REPO,
+)
+from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
+from metagpt.schema import ClassAttribute, ClassMethod, ClassView
+from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
- def __init__(self, name="", context=None, llm=None):
- super().__init__(name=name, context=context, llm=llm)
-
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
- repo_parser = RepoParser(base_directory=self.context)
- class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
+ repo_parser = RepoParser(base_directory=Path(self.context))
+ class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
+ await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
- await self._create_mermaid_class_view(graph_db=graph_db)
- await self._save(graph_db=graph_db)
-
- async def _create_mermaid_class_view(self, graph_db):
- pass
- # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
- # if not dataset:
- # logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
- # return
- # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
- # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
- # code_type = ""
- # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
- # for spo in dataset:
- # if spo.object_ in ["javascript", "python"]:
- # code_type = spo.object_
- # break
-
- # try:
- # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
- # class_view = node.instruct_content.model_dump()["Class View"]
- # except Exception as e:
- # class_view = RepoParser.rebuild_class_view(src_code, code_type)
- # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
- # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
-
- async def _save(self, graph_db):
- class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
- dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
- all_class_view = []
- for spo in dataset:
- title = f"---\ntitle: {spo.subject}\n---\n"
- filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
- await class_view_file_repo.save(filename=filename, content=title + spo.object_)
- all_class_view.append(spo.object_)
- await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
+ await self._create_mermaid_class_views(graph_db=graph_db)
+ await graph_db.save()
+
+ async def _create_mermaid_class_views(self, graph_db):
+ path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
+ path.mkdir(parents=True, exist_ok=True)
+ pathname = path / CONFIG.git_repo.workdir.name
+ async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
+ content = "classDiagram\n"
+ logger.debug(content)
+ await writer.write(content)
+ # class names
+ rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
+ class_distinct = set()
+ relationship_distinct = set()
+ for r in rows:
+ await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
+ for r in rows:
+ await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
+
+ @staticmethod
+ async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
+ fields = split_namespace(ns_class_name)
+ if len(fields) > 2:
+ # Ignore sub-class
+ return
+
+ class_view = ClassView(name=fields[1])
+ rows = await graph_db.select(subject=ns_class_name)
+ for r in rows:
+ name = split_namespace(r.object_)[-1]
+ name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
+ if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
+ var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
+ attribute = ClassAttribute(
+ name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
+ )
+ class_view.attributes.append(attribute)
+ elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
+ method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
+ await RebuildClassView._parse_function_args(method, r.object_, graph_db)
+ class_view.methods.append(method)
+
+ # update graph db
+ await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
+
+ content = class_view.get_mermaid(align=1)
+ logger.debug(content)
+ await file_writer.write(content)
+ distinct.add(ns_class_name)
+
+ @staticmethod
+ async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
+ s_fields = split_namespace(ns_class_name)
+ if len(s_fields) > 2:
+ # Ignore sub-class
+ return
+
+ predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
+ mappings = {
+ GENERALIZATION: " <|-- ",
+ COMPOSITION: " *-- ",
+ AGGREGATION: " o-- ",
+ }
+ content = ""
+ for p, v in predicates.items():
+ rows = await graph_db.select(subject=ns_class_name, predicate=p)
+ for r in rows:
+ o_fields = split_namespace(r.object_)
+ if len(o_fields) > 2:
+ # Ignore sub-class
+ continue
+ relationship = mappings.get(v, " .. ")
+ link = f"{o_fields[1]}{relationship}{s_fields[1]}"
+ distinct.add(link)
+ content += f"\t{link}\n"
+
+ if content:
+ logger.debug(content)
+ await file_writer.write(content)
+
+ @staticmethod
+ def _parse_name(name: str, language="python"):
+ pattern = re.compile(r"<I>(.*?)<\/I>")
+ result = re.search(pattern, name)
+
+ abstraction = ""
+ if result:
+ name = result.group(1)
+ abstraction = "*"
+ if name.startswith("__"):
+ visibility = "-"
+ elif name.startswith("_"):
+ visibility = "#"
+ else:
+ visibility = "+"
+ return name, visibility, abstraction
+
+ @staticmethod
+ async def _parse_variable_type(ns_name, graph_db) -> str:
+ rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
+ if not rows:
+ return ""
+ vals = rows[0].object_.replace("'", "").split(":")
+ if len(vals) == 1:
+ return ""
+ val = vals[-1].strip()
+ return "" if val == "NoneType" else val + " "
+
+ @staticmethod
+ async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
+ rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
+ if not rows:
+ return
+ info = rows[0].object_.replace("'", "")
+
+ fs_tag = "("
+ ix = info.find(fs_tag)
+ fe_tag = "):"
+ eix = info.rfind(fe_tag)
+ if eix < 0:
+ fe_tag = ")"
+ eix = info.rfind(fe_tag)
+ args_info = info[ix + len(fs_tag) : eix].strip()
+ method.return_type = info[eix + len(fe_tag) :].strip()
+ if method.return_type == "None":
+ method.return_type = ""
+ if "(" in method.return_type:
+ method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
+
+ # parse args
+ if not args_info:
+ return
+ splitter_ixs = []
+ cost = 0
+ for i in range(len(args_info)):
+ if args_info[i] == "[":
+ cost += 1
+ elif args_info[i] == "]":
+ cost -= 1
+ if args_info[i] == "," and cost == 0:
+ splitter_ixs.append(i)
+ splitter_ixs.append(len(args_info))
+ args = []
+ ix = 0
+ for eix in splitter_ixs:
+ args.append(args_info[ix:eix])
+ ix = eix + 1
+ for arg in args:
+ parts = arg.strip().split(":")
+ if len(parts) == 1:
+ method.args.append(ClassAttribute(name=parts[0].strip()))
+ continue
+ method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py
index 25c4912c3..7377442b5 100644
--- a/metagpt/actions/write_code.py
+++ b/metagpt/actions/write_code.py
@@ -130,7 +130,7 @@ async def run(self, *args, **kwargs) -> CodingContext:
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
- coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
+ coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
return coding_context
diff --git a/metagpt/const.py b/metagpt/const.py
index a57be641b..811ff9516 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -126,3 +126,8 @@ def get_metagpt_root():
# Message id
IGNORED_MESSAGE_ID = "0"
+
+# Class Relationship
+GENERALIZATION = "Generalize"
+COMPOSITION = "Composite"
+AGGREGATION = "Aggregate"
diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py
index 5e4d67940..9863a29ae 100644
--- a/metagpt/repo_parser.py
+++ b/metagpt/repo_parser.py
@@ -12,14 +12,14 @@
import re
import subprocess
from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
-import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
+from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
-from metagpt.utils.common import any_to_str
+from metagpt.utils.common import any_to_str, aread
from metagpt.utils.exceptions import handle_exception
@@ -46,6 +46,13 @@ class ClassInfo(BaseModel):
methods: Dict[str, str] = Field(default_factory=dict)
+class ClassRelationship(BaseModel):
+ src: str = ""
+ dest: str = ""
+ relationship: str = ""
+ label: Optional[str] = None
+
+
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@@ -60,7 +67,8 @@ def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
- file_info.page_info.append(info)
+ if info:
+ file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info.classes.append({"name": node.name, "methods": class_methods})
@@ -110,7 +118,9 @@ def generate_structure(self, output_path=None, mode="json") -> Path:
return output_path
@staticmethod
- def node_to_str(node) -> (int, int, str, str | Tuple):
+ def node_to_str(node) -> CodeBlockInfo | None:
+ if isinstance(node, ast.Try):
+ return None
if any_to_str(node) == any_to_str(ast.Expr):
return CodeBlockInfo(
lineno=node.lineno,
@@ -129,6 +139,7 @@ def node_to_str(node) -> (int, int, str, str | Tuple):
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
+ any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target),
}
func = mappings.get(any_to_str(node))
if func:
@@ -143,7 +154,8 @@ def node_to_str(node) -> (int, int, str, str | Tuple):
else:
raise NotImplementedError(f"Not implement:{val}")
return code_block
- raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
+ logger.warning(f"Unsupported code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
+ return None
@staticmethod
def _parse_expr(node) -> List:
@@ -164,22 +176,51 @@ def _parse_name(n):
@staticmethod
def _parse_if(n):
- tokens = [RepoParser._parse_variable(n.test.left)]
- for item in n.test.comparators:
- tokens.append(RepoParser._parse_variable(item))
+ tokens = []
+ try:
+ if isinstance(n.test, ast.BoolOp):
+ tokens = []
+ for v in n.test.values:
+ tokens.extend(RepoParser._parse_if_compare(v))
+ return tokens
+ if isinstance(n.test, ast.Compare):
+ v = RepoParser._parse_variable(n.test.left)
+ if v:
+ tokens.append(v)
+ for item in n.test.comparators:
+ v = RepoParser._parse_variable(item)
+ if v:
+ tokens.append(v)
+ return tokens
+ except Exception as e:
+ logger.warning(f"Unsupported if: {n}, err:{e}")
return tokens
+ @staticmethod
+ def _parse_if_compare(n):
+ if hasattr(n, "left"):
+ return RepoParser._parse_variable(n.left)
+ else:
+ return []
+
@staticmethod
def _parse_variable(node):
- funcs = {
- any_to_str(ast.Constant): lambda x: x.value,
- any_to_str(ast.Name): lambda x: x.id,
- any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
- }
- func = funcs.get(any_to_str(node))
- if not func:
- raise NotImplementedError(f"Not implement:{node}")
- return func(node)
+ try:
+ funcs = {
+ any_to_str(ast.Constant): lambda x: x.value,
+ any_to_str(ast.Name): lambda x: x.id,
+ any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}"
+ if hasattr(x.value, "id")
+ else f"{x.attr}",
+ any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
+ any_to_str(ast.Tuple): lambda x: "",
+ }
+ func = funcs.get(any_to_str(node))
+ if not func:
+ raise NotImplementedError(f"Not implement:{node}")
+ return func(node)
+ except Exception as e:
+ logger.warning(f"Unsupported variable:{node}, err:{e}")
@staticmethod
def _parse_assign(node):
@@ -197,18 +238,21 @@ async def rebuild_class_views(self, path: str | Path = None):
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
+ relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
- class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
+ class_views, relationship_views = RepoParser._repair_namespaces(
+ class_views=class_views, relationship_views=relationship_views, path=path
+ )
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
- return class_views
+ return class_views, relationship_views
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
- async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
- lines = await reader.readlines()
+ data = await aread(filename=class_view_pathname, encoding="utf-8")
+ lines = data.split("\n")
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
@@ -229,6 +273,19 @@ async def _parse_classes(self, class_view_pathname):
class_views.append(class_info)
return class_views
+ async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
+ relationship_views = []
+ if not class_view_pathname.exists():
+ return relationship_views
+ data = await aread(filename=class_view_pathname, encoding="utf-8")
+ lines = data.split("\n")
+ for line in lines:
+ relationship = RepoParser._split_relationship_line(line)
+ if not relationship:
+ continue
+ relationship_views.append(relationship)
+ return relationship_views
+
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
@@ -247,6 +304,40 @@ def _split_class_line(line):
info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
+ @staticmethod
+ def _split_relationship_line(line):
+ splitters = [" -> ", " [", "];"]
+ idxs = []
+ for tag in splitters:
+ if tag not in line:
+ return None
+ idxs.append(line.find(tag))
+ ret = ClassRelationship()
+ ret.src = line[0 : idxs[0]].strip('"')
+ ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
+ properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
+ mappings = {
+ 'arrowhead="empty"': GENERALIZATION,
+ 'arrowhead="diamond"': COMPOSITION,
+ 'arrowhead="odiamond"': AGGREGATION,
+ }
+ for k, v in mappings.items():
+ if k in properties:
+ ret.relationship = v
+ if v != GENERALIZATION:
+ ret.label = RepoParser._get_label(properties)
+ break
+ return ret
+
+ @staticmethod
+ def _get_label(line):
+ tag = 'label="'
+ if tag not in line:
+ return ""
+ ix = line.find(tag)
+ eix = line.find('"', ix + len(tag))
+ return line[ix + len(tag) : eix]
+
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
@@ -271,7 +362,9 @@ def _create_path_mapping(path: str | Path) -> Dict[str, str]:
return mappings
@staticmethod
- def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
+ def _repair_namespaces(
+ class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
+ ) -> (List[ClassInfo], List[ClassRelationship]):
if not class_views:
return []
c = class_views[0]
@@ -290,7 +383,12 @@ def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[C
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
- return class_views
+ for i in range(len(relationship_views)):
+ v = relationship_views[i]
+ v.src = RepoParser._repair_ns(v.src, new_mappings)
+ v.dest = RepoParser._repair_ns(v.dest, new_mappings)
+ relationship_views[i] = v
+ return class_views, relationship_views
@staticmethod
def _repair_ns(package, mappings):
diff --git a/metagpt/schema.py b/metagpt/schema.py
index e36bef395..02d44f767 100644
--- a/metagpt/schema.py
+++ b/metagpt/schema.py
@@ -451,3 +451,63 @@ def __hash__(self):
class BugFixContext(BaseContext):
filename: str = ""
+
+
+# mermaid class view
+class ClassMeta(BaseModel):
+ name: str = ""
+ abstraction: bool = False
+ static: bool = False
+ visibility: str = ""
+
+
+class ClassAttribute(ClassMeta):
+ value_type: str = ""
+ default_value: str = ""
+
+ def get_mermaid(self, align=1) -> str:
+ content = "".join(["\t" for i in range(align)]) + self.visibility
+ if self.value_type:
+ content += self.value_type + " "
+ content += self.name
+ if self.default_value:
+ content += "="
+ if self.value_type not in ["str", "string", "String"]:
+ content += self.default_value
+ else:
+ content += '"' + self.default_value.replace('"', "") + '"'
+ if self.abstraction:
+ content += "*"
+ if self.static:
+ content += "$"
+ return content
+
+
+class ClassMethod(ClassMeta):
+ args: List[ClassAttribute] = Field(default_factory=list)
+ return_type: str = ""
+
+ def get_mermaid(self, align=1) -> str:
+ content = "".join(["\t" for i in range(align)]) + self.visibility
+ content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
+ if self.return_type:
+ content += ":" + self.return_type
+ if self.abstraction:
+ content += "*"
+ if self.static:
+ content += "$"
+ return content
+
+
+class ClassView(ClassMeta):
+ attributes: List[ClassAttribute] = Field(default_factory=list)
+ methods: List[ClassMethod] = Field(default_factory=list)
+
+ def get_mermaid(self, align=1) -> str:
+ content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
+ for v in self.attributes:
+ content += v.get_mermaid(align=align + 1) + "\n"
+ for v in self.methods:
+ content += v.get_mermaid(align=align + 1) + "\n"
+ content += "".join(["\t" for i in range(align)]) + "}\n"
+ return content
diff --git a/metagpt/tools/metagpt_oas3_api_svc.py b/metagpt/tools/metagpt_oas3_api_svc.py
index 319e7efb2..8e9f4a0da 100644
--- a/metagpt/tools/metagpt_oas3_api_svc.py
+++ b/metagpt/tools/metagpt_oas3_api_svc.py
@@ -5,6 +5,12 @@
@Author : mashenquan
@File : metagpt_oas3_api_svc.py
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
+
+ curl -X 'POST' \
+ 'http://localhost:8080/openapi/greeting/dave' \
+ -H 'accept: text/plain' \
+ -H 'Content-Type: application/json' \
+ -d '{}'
"""
from pathlib import Path
@@ -15,7 +21,7 @@
def oas_http_svc():
"""Start the OAS 3.0 OpenAPI HTTP service"""
print("http://localhost:8080/oas3/ui/")
- specification_dir = Path(__file__).parent.parent.parent / ".well-known"
+ specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
app.add_api("metagpt_oas3_api.yaml")
app.add_api("openapi.yaml")
diff --git a/metagpt/tools/openapi_v3_hello.py b/metagpt/tools/openapi_v3_hello.py
index c8f5de42d..d1c83eac2 100644
--- a/metagpt/tools/openapi_v3_hello.py
+++ b/metagpt/tools/openapi_v3_hello.py
@@ -23,7 +23,7 @@ async def post_greeting(name: str) -> str:
if __name__ == "__main__":
- specification_dir = Path(__file__).parent.parent.parent / ".well-known"
+ specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
app.run(port=8082)
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index c7751c2af..0032f0b0d 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -407,6 +407,10 @@ def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
+def split_namespace(ns_class_name: str) -> List[str]:
+ return ns_class_name.split(":")
+
+
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py
index 08f4327fa..8bb5f9bb3 100644
--- a/metagpt/utils/di_graph_repository.py
+++ b/metagpt/utils/di_graph_repository.py
@@ -12,9 +12,9 @@
from pathlib import Path
from typing import List
-import aiofiles
import networkx
+from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository
@@ -55,12 +55,10 @@ async def save(self, path: str | Path = None):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
- async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
- await writer.write(data)
+ await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
- async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
- data = await reader.read(-1)
+ data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py
index ff750fbbb..0ddca414d 100644
--- a/metagpt/utils/file_repository.py
+++ b/metagpt/utils/file_repository.py
@@ -138,6 +138,8 @@ def changed_files(self) -> Dict[str, str]:
files = self._git_repo.changed_files
relative_files = {}
for p, ct in files.items():
+ if ct.value == "D": # deleted
+ continue
try:
rf = Path(p).relative_to(self._relative_path)
except ValueError:
diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py
index 37da3dee4..88946c98e 100644
--- a/metagpt/utils/graph_repository.py
+++ b/metagpt/utils/graph_repository.py
@@ -13,19 +13,25 @@
from pydantic import BaseModel
-from metagpt.repo_parser import ClassInfo, RepoFileInfo
+from metagpt.logs import logger
+from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
+ OF = "Of"
+ ON = "On"
CLASS = "class"
FUNCTION = "function"
+ HAS_FUNCTION = "has_function"
SOURCE_CODE = "source_code"
NULL = ""
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
+ HAS_CLASS_FUNCTION = "has_class_function"
+ HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
@@ -73,11 +79,13 @@ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info:
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
+ # file -> class
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
+ # class detail
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
@@ -85,12 +93,22 @@ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info:
)
methods = c.get("methods", [])
for fn in methods:
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, class_name),
+ predicate=GraphKeyword.HAS_CLASS_FUNCTION,
+ object_=concat_namespace(file_info.file, class_name, fn),
+ )
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
+ # file -> function
+ await graph_db.insert(
+ subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
+ )
+ # function detail
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
@@ -105,13 +123,13 @@ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info:
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
- object_=code_block.json(ensure_ascii=False),
+ object_=code_block.model_dump_json(),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
- object_=code_block.json(ensure_ascii=False),
+ object_=code_block.model_dump_json(),
)
@staticmethod
@@ -129,6 +147,13 @@ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_vi
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
+ # class -> property
+ await graph_db.insert(
+ subject=c.package,
+ predicate=GraphKeyword.HAS_CLASS_PROPERTY,
+ object_=concat_namespace(c.package, vn),
+ )
+ # property detail
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
@@ -138,6 +163,15 @@ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_vi
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
+ if "<I>" in desc and "</I>" not in desc:
+ logger.error(desc)
+ # class -> function
+ await graph_db.insert(
+ subject=c.package,
+ predicate=GraphKeyword.HAS_CLASS_FUNCTION,
+ object_=concat_namespace(c.package, fn),
+ )
+ # function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
@@ -148,3 +182,19 @@ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_vi
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)
+
+ @staticmethod
+ async def update_graph_db_with_class_relationship_views(
+ graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
+ ):
+ for r in relationship_views:
+ await graph_db.insert(
+ subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
+ )
+ if not r.label:
+ continue
+ await graph_db.insert(
+ subject=r.src,
+ predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
+ object_=concat_namespace(r.dest, r.label),
+ )
diff --git a/tests/conftest.py b/tests/conftest.py
index fbf9ff465..a15e3e85b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,6 +11,7 @@
import logging
import os
import re
+import uuid
from typing import Optional
import pytest
@@ -151,9 +152,9 @@ def emit(self, record):
# init & dispose git repo
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown_git_repo(request):
- CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
+ CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
CONFIG.git_reinit = True
# Destroy git repo at the end of the test session.
diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py
index 955c6ae3b..0103e9d05 100644
--- a/tests/metagpt/actions/test_rebuild_class_view.py
+++ b/tests/metagpt/actions/test_rebuild_class_view.py
@@ -11,13 +11,19 @@
import pytest
from metagpt.actions.rebuild_class_view import RebuildClassView
+from metagpt.config import CONFIG
+from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_rebuild():
- action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
+ action = RebuildClassView(
+ name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
+ )
await action.run()
+ graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
+ assert graph_file_repo.changed_files
if __name__ == "__main__":
diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py
index 816c186e2..b6e334fbe 100644
--- a/tests/metagpt/test_schema.py
+++ b/tests/metagpt/test_schema.py
@@ -19,6 +19,9 @@
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
+ ClassAttribute,
+ ClassMethod,
+ ClassView,
CodeSummarizeContext,
Document,
Message,
@@ -156,5 +159,30 @@ def test_CodeSummarizeContext(file_list, want):
assert want in m
+def test_class_view():
+ attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
+ assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
+ attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
+ assert attr_b.get_mermaid(align=0) == '#str b="0"$'
+ class_view = ClassView(name="A")
+ class_view.attributes = [attr_a, attr_b]
+
+ method_a = ClassMethod(name="run", visibility="+", abstraction=True)
+ assert method_a.get_mermaid(align=1) == "\t+run()*"
+ method_b = ClassMethod(
+ name="_test",
+ visibility="#",
+ static=True,
+ args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
+ return_type="str",
+ )
+ assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
+ class_view.methods = [method_a, method_b]
+ assert (
+ class_view.get_mermaid(align=0)
+ == 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/tools/test_metagpt_oas3_api_svc.py b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py
index 1135860eb..5f52b28cc 100644
--- a/tests/metagpt/tools/test_metagpt_oas3_api_svc.py
+++ b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py
@@ -24,13 +24,14 @@ async def test_oas2_svc():
process = subprocess.Popen(["python", str(script_pathname)], cwd=str(workdir), env=env)
await asyncio.sleep(5)
- url = "http://localhost:8080/openapi/greeting/dave"
- headers = {"accept": "text/plain", "Content-Type": "application/json"}
- data = {}
- response = requests.post(url, headers=headers, json=data)
- assert response.text == "Hello dave\n"
-
- process.terminate()
+ try:
+ url = "http://localhost:8080/openapi/greeting/dave"
+ headers = {"accept": "text/plain", "Content-Type": "application/json"}
+ data = {}
+ response = requests.post(url, headers=headers, json=data)
+ assert response.text == "Hello dave\n"
+ finally:
+ process.terminate()
if __name__ == "__main__":
diff --git a/tests/metagpt/tools/test_hello.py b/tests/metagpt/tools/test_openapi_v3_hello.py
similarity index 65%
rename from tests/metagpt/tools/test_hello.py
rename to tests/metagpt/tools/test_openapi_v3_hello.py
index 7e61532ab..5726cf8e0 100644
--- a/tests/metagpt/tools/test_hello.py
+++ b/tests/metagpt/tools/test_openapi_v3_hello.py
@@ -3,7 +3,7 @@
"""
@Time : 2023/12/26
@Author : mashenquan
-@File : test_hello.py
+@File : test_openapi_v3_hello.py
"""
import asyncio
import subprocess
@@ -24,13 +24,14 @@ async def test_hello():
process = subprocess.Popen(["python", str(script_pathname)], cwd=workdir, env=env)
await asyncio.sleep(5)
- url = "http://localhost:8082/openapi/greeting/dave"
- headers = {"accept": "text/plain", "Content-Type": "application/json"}
- data = {}
- response = requests.post(url, headers=headers, json=data)
- assert response.text == "Hello dave\n"
-
- process.terminate()
+ try:
+ url = "http://localhost:8082/openapi/greeting/dave"
+ headers = {"accept": "text/plain", "Content-Type": "application/json"}
+ data = {}
+ response = requests.post(url, headers=headers, json=data)
+ assert response.text == "Hello dave\n"
+ finally:
+ process.terminate()
if __name__ == "__main__":
diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py
index 0342a92af..9b1fa878e 100644
--- a/tests/metagpt/utils/test_common.py
+++ b/tests/metagpt/utils/test_common.py
@@ -36,6 +36,7 @@
read_file_block,
read_json_file,
require_python_version,
+ split_namespace,
)
@@ -163,6 +164,23 @@ def test_concat_namespace(self):
assert concat_namespace("a", "b", "c", "e") == "a:b:c:e"
assert concat_namespace("a", "b", "c", "e", "f") == "a:b:c:e:f"
+ @pytest.mark.parametrize(
+ ("val", "want"),
+ [
+ (
+ "tests/metagpt/test_role.py:test_react:Input:subscription",
+ ["tests/metagpt/test_role.py", "test_react", "Input", "subscription"],
+ ),
+ (
+ "tests/metagpt/test_role.py:test_react:Input:goal",
+ ["tests/metagpt/test_role.py", "test_react", "Input", "goal"],
+ ),
+ ],
+ )
+ def test_split_namespace(self, val, want):
+ res = split_namespace(val)
+ assert res == want
+
def test_read_json_file(self):
assert read_json_file(str(Path(__file__).parent / "../../data/ut_writer/yft_swaggerApi.json"), encoding="utf-8")
with pytest.raises(FileNotFoundError):
diff --git a/tests/metagpt/utils/test_redis.py b/tests/metagpt/utils/test_redis.py
index b93ff0cdb..d499418ac 100644
--- a/tests/metagpt/utils/test_redis.py
+++ b/tests/metagpt/utils/test_redis.py
@@ -6,20 +6,34 @@
@File : test_redis.py
"""
+import mock
import pytest
from metagpt.config import CONFIG
from metagpt.utils.redis import Redis
+async def async_mock_from_url(*args, **kwargs):
+ mock_client = mock.AsyncMock()
+ mock_client.set.return_value = None
+ mock_client.get.side_effect = [b"test", b""]
+ return mock_client
+
+
@pytest.mark.asyncio
-async def test_redis():
+@mock.patch("aioredis.from_url", return_value=async_mock_from_url())
+async def test_redis(mock_from_url):
+ # Mock
+ # mock_client = mock.AsyncMock()
+ # mock_client.set.return_value=None
+ # mock_client.get.side_effect = [b'test', b'']
+ # mock_from_url.return_value = mock_client
+
# Prerequisites
- assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
- assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
- # assert CONFIG.REDIS_USER
- assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD"
- assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based"
+ CONFIG.REDIS_HOST = "MOCK_REDIS_HOST"
+ CONFIG.REDIS_PORT = "MOCK_REDIS_PORT"
+ CONFIG.REDIS_PASSWORD = "MOCK_REDIS_PASSWORD"
+ CONFIG.REDIS_DB = 0
conn = Redis()
assert not conn.is_valid
diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py
index f74e7b52a..132aa0635 100644
--- a/tests/metagpt/utils/test_s3.py
+++ b/tests/metagpt/utils/test_s3.py
@@ -9,20 +9,36 @@
from pathlib import Path
import aiofiles
+import mock
import pytest
from metagpt.config import CONFIG
+from metagpt.utils.common import aread
from metagpt.utils.s3 import S3
@pytest.mark.asyncio
-async def test_s3():
+@mock.patch("aioboto3.Session")
+async def test_s3(mock_session_class):
+ # Set up the mock response
+ data = await aread(__file__, "utf-8")
+ mock_session_object = mock.Mock()
+ reader_mock = mock.AsyncMock()
+ reader_mock.read.side_effect = [data.encode("utf-8"), b"", data.encode("utf-8")]
+ type(reader_mock).url = mock.PropertyMock(return_value="https://mock")
+ mock_client = mock.AsyncMock()
+ mock_client.put_object.return_value = None
+ mock_client.get_object.return_value = {"Body": reader_mock}
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__.return_value = None
+ mock_session_object.client.return_value = mock_client
+ mock_session_class.return_value = mock_session_object
+
# Prerequisites
- assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
- assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
- assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
- # assert CONFIG.S3_SECURE: true # true/false
- assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
+ # assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
+ # assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
+ # assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
+ # assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
conn = S3()
assert conn.is_valid
@@ -42,6 +58,7 @@ async def test_s3():
assert "http" in res
# Mock session env
+ type(reader_mock).url = mock.PropertyMock(return_value="")
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
@@ -54,6 +71,8 @@ async def test_s3():
finally:
CONFIG.set_context(old_options)
+ await reader.close()
+
if __name__ == "__main__":
pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/utils/test_session.py b/tests/metagpt/utils/test_session.py
new file mode 100644
index 000000000..eab2587a2
--- /dev/null
+++ b/tests/metagpt/utils/test_session.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import pytest
+
+
+def test_nodeid(request):
+ print(request.node.nodeid)
+ assert request.node.nodeid
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-s"])