diff --git a/.gitignore b/.gitignore index 6dd3608f1..c10191dcc 100644 --- a/.gitignore +++ b/.gitignore @@ -172,6 +172,8 @@ tests/metagpt/utils/file_repo_git *.png htmlcov htmlcov.* +*.dot *.pkl *-structure.csv *-structure.json + diff --git a/docs/.well-known/openapi.yaml b/docs/.well-known/openapi.yaml index bc291b7db..47ca04b23 100644 --- a/docs/.well-known/openapi.yaml +++ b/docs/.well-known/openapi.yaml @@ -11,7 +11,7 @@ paths: post: summary: Generate greeting description: Generates a greeting message. - operationId: hello.post_greeting + operationId: openapi_v3_hello.post_greeting responses: 200: description: greeting response diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index 66bc2c7ab..dbc11d14b 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -9,60 +9,187 @@ import re from pathlib import Path +import aiofiles + from metagpt.actions import Action from metagpt.config import CONFIG -from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO +from metagpt.const import ( + AGGREGATION, + COMPOSITION, + DATA_API_DESIGN_FILE_REPO, + GENERALIZATION, + GRAPH_REPO_FILE_REPO, +) +from metagpt.logs import logger from metagpt.repo_parser import RepoParser +from metagpt.schema import ClassAttribute, ClassMethod, ClassView +from metagpt.utils.common import split_namespace from metagpt.utils.di_graph_repository import DiGraphRepository from metagpt.utils.graph_repository import GraphKeyword, GraphRepository class RebuildClassView(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name=name, context=context, llm=llm) - async def run(self, with_messages=None, format=CONFIG.prompt_schema): graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) - repo_parser = RepoParser(base_directory=self.context) - class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint + repo_parser = RepoParser(base_directory=Path(self.context)) + class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint await GraphRepository.update_graph_db_with_class_views(graph_db, class_views) + await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views) symbols = repo_parser.generate_symbols() # use ast for file_info in symbols: await GraphRepository.update_graph_db_with_file_info(graph_db, file_info) - await self._create_mermaid_class_view(graph_db=graph_db) - await self._save(graph_db=graph_db) - - async def _create_mermaid_class_view(self, graph_db): - pass - # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO) - # if not dataset: - # logger.warning(f"No page info for {concat_namespace(filename, class_name)}") - # return - # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_) - # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno) - # code_type = "" - # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS) - # for spo in dataset: - # if spo.object_ in ["javascript", "python"]: - # code_type = spo.object_ - # break - - # try: - # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format) - # class_view = node.instruct_content.model_dump()["Class View"] - # except Exception as e: - # class_view = RepoParser.rebuild_class_view(src_code, code_type) - # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view) - # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}") - - async def _save(self, graph_db): - class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) - dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW) - all_class_view = [] - for spo in dataset: - title = f"---\ntitle: {spo.subject}\n---\n" - filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd" - await class_view_file_repo.save(filename=filename, content=title + spo.object_) - all_class_view.append(spo.object_) - await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view)) + await self._create_mermaid_class_views(graph_db=graph_db) + await graph_db.save() + + async def _create_mermaid_class_views(self, graph_db): + path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO + path.mkdir(parents=True, exist_ok=True) + pathname = path / CONFIG.git_repo.workdir.name + async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer: + content = "classDiagram\n" + logger.debug(content) + await writer.write(content) + # class names + rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + class_distinct = set() + relationship_distinct = set() + for r in rows: + await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct) + for r in rows: + await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct) + + @staticmethod + async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct): + fields = split_namespace(ns_class_name) + if len(fields) > 2: + # Ignore sub-class + return + + class_view = ClassView(name=fields[1]) + rows = await graph_db.select(subject=ns_class_name) + for r in rows: + name = split_namespace(r.object_)[-1] + name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python") + if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY: + var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db) + attribute = ClassAttribute( + name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type + ) + class_view.attributes.append(attribute) + elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION: + method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction)) + await RebuildClassView._parse_function_args(method, r.object_, graph_db) + class_view.methods.append(method) + + # update graph db + await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json()) + + content = class_view.get_mermaid(align=1) + logger.debug(content) + await file_writer.write(content) + distinct.add(ns_class_name) + + @staticmethod + async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct): + s_fields = split_namespace(ns_class_name) + if len(s_fields) > 2: + # Ignore sub-class + return + + predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]} + mappings = { + GENERALIZATION: " <|-- ", + COMPOSITION: " *-- ", + AGGREGATION: " o-- ", + } + content = "" + for p, v in predicates.items(): + rows = await graph_db.select(subject=ns_class_name, predicate=p) + for r in rows: + o_fields = split_namespace(r.object_) + if len(o_fields) > 2: + # Ignore sub-class + continue + relationship = mappings.get(v, " .. ") + link = f"{o_fields[1]}{relationship}{s_fields[1]}" + distinct.add(link) + content += f"\t{link}\n" + + if content: + logger.debug(content) + await file_writer.write(content) + + @staticmethod + def _parse_name(name: str, language="python"): + pattern = re.compile(r"(.*?)<\/I>") + result = re.search(pattern, name) + + abstraction = "" + if result: + name = result.group(1) + abstraction = "*" + if name.startswith("__"): + visibility = "-" + elif name.startswith("_"): + visibility = "#" + else: + visibility = "+" + return name, visibility, abstraction + + @staticmethod + async def _parse_variable_type(ns_name, graph_db) -> str: + rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC) + if not rows: + return "" + vals = rows[0].object_.replace("'", "").split(":") + if len(vals) == 1: + return "" + val = vals[-1].strip() + return "" if val == "NoneType" else val + " " + + @staticmethod + async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository): + rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC) + if not rows: + return + info = rows[0].object_.replace("'", "") + + fs_tag = "(" + ix = info.find(fs_tag) + fe_tag = "):" + eix = info.rfind(fe_tag) + if eix < 0: + fe_tag = ")" + eix = info.rfind(fe_tag) + args_info = info[ix + len(fs_tag) : eix].strip() + method.return_type = info[eix + len(fe_tag) :].strip() + if method.return_type == "None": + method.return_type = "" + if "(" in method.return_type: + method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]") + + # parse args + if not args_info: + return + splitter_ixs = [] + cost = 0 + for i in range(len(args_info)): + if args_info[i] == "[": + cost += 1 + elif args_info[i] == "]": + cost -= 1 + if args_info[i] == "," and cost == 0: + splitter_ixs.append(i) + splitter_ixs.append(len(args_info)) + args = [] + ix = 0 + for eix in splitter_ixs: + args.append(args_info[ix:eix]) + ix = eix + 1 + for arg in args: + parts = arg.strip().split(":") + if len(parts) == 1: + method.args.append(ClassAttribute(name=parts[0].strip())) + continue + method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip())) diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 25c4912c3..7377442b5 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -130,7 +130,7 @@ async def run(self, *args, **kwargs) -> CodingContext: if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone root_path = CONFIG.src_workspace if CONFIG.src_workspace else "" - coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path) + coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/const.py b/metagpt/const.py index a57be641b..811ff9516 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -126,3 +126,8 @@ def get_metagpt_root(): # Message id IGNORED_MESSAGE_ID = "0" + +# Class Relationship +GENERALIZATION = "Generalize" +COMPOSITION = "Composite" +AGGREGATION = "Aggregate" diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 5e4d67940..9863a29ae 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -12,14 +12,14 @@ import re import subprocess from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional -import aiofiles import pandas as pd from pydantic import BaseModel, Field +from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION from metagpt.logs import logger -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, aread from metagpt.utils.exceptions import handle_exception @@ -46,6 +46,13 @@ class ClassInfo(BaseModel): methods: Dict[str, str] = Field(default_factory=dict) +class ClassRelationship(BaseModel): + src: str = "" + dest: str = "" + relationship: str = "" + label: Optional[str] = None + + class RepoParser(BaseModel): base_directory: Path = Field(default=None) @@ -60,7 +67,8 @@ def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo: file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory))) for node in tree: info = RepoParser.node_to_str(node) - file_info.page_info.append(info) + if info: + file_info.page_info.append(info) if isinstance(node, ast.ClassDef): class_methods = [m.name for m in node.body if is_func(m)] file_info.classes.append({"name": node.name, "methods": class_methods}) @@ -110,7 +118,9 @@ def generate_structure(self, output_path=None, mode="json") -> Path: return output_path @staticmethod - def node_to_str(node) -> (int, int, str, str | Tuple): + def node_to_str(node) -> CodeBlockInfo | None: + if isinstance(node, ast.Try): + return None if any_to_str(node) == any_to_str(ast.Expr): return CodeBlockInfo( lineno=node.lineno, @@ -129,6 +139,7 @@ def node_to_str(node) -> (int, int, str, str | Tuple): }, any_to_str(ast.If): RepoParser._parse_if, any_to_str(ast.AsyncFunctionDef): lambda x: x.name, + any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target), } func = mappings.get(any_to_str(node)) if func: @@ -143,7 +154,8 @@ def node_to_str(node) -> (int, int, str, str | Tuple): else: raise NotImplementedError(f"Not implement:{val}") return code_block - raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}") + logger.warning(f"Unsupported code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}") + return None @staticmethod def _parse_expr(node) -> List: @@ -164,22 +176,51 @@ def _parse_name(n): @staticmethod def _parse_if(n): - tokens = [RepoParser._parse_variable(n.test.left)] - for item in n.test.comparators: - tokens.append(RepoParser._parse_variable(item)) + tokens = [] + try: + if isinstance(n.test, ast.BoolOp): + tokens = [] + for v in n.test.values: + tokens.extend(RepoParser._parse_if_compare(v)) + return tokens + if isinstance(n.test, ast.Compare): + v = RepoParser._parse_variable(n.test.left) + if v: + tokens.append(v) + for item in n.test.comparators: + v = RepoParser._parse_variable(item) + if v: + tokens.append(v) + return tokens + except Exception as e: + logger.warning(f"Unsupported if: {n}, err:{e}") return tokens + @staticmethod + def _parse_if_compare(n): + if hasattr(n, "left"): + return RepoParser._parse_variable(n.left) + else: + return [] + @staticmethod def _parse_variable(node): - funcs = { - any_to_str(ast.Constant): lambda x: x.value, - any_to_str(ast.Name): lambda x: x.id, - any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}", - } - func = funcs.get(any_to_str(node)) - if not func: - raise NotImplementedError(f"Not implement:{node}") - return func(node) + try: + funcs = { + any_to_str(ast.Constant): lambda x: x.value, + any_to_str(ast.Name): lambda x: x.id, + any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}" + if hasattr(x.value, "id") + else f"{x.attr}", + any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func), + any_to_str(ast.Tuple): lambda x: "", + } + func = funcs.get(any_to_str(node)) + if not func: + raise NotImplementedError(f"Not implement:{node}") + return func(node) + except Exception as e: + logger.warning(f"Unsupported variable:{node}, err:{e}") @staticmethod def _parse_assign(node): @@ -197,18 +238,21 @@ async def rebuild_class_views(self, path: str | Path = None): raise ValueError(f"{result}") class_view_pathname = path / "classes.dot" class_views = await self._parse_classes(class_view_pathname) + relationship_views = await self._parse_class_relationships(class_view_pathname) packages_pathname = path / "packages.dot" - class_views = RepoParser._repair_namespaces(class_views=class_views, path=path) + class_views, relationship_views = RepoParser._repair_namespaces( + class_views=class_views, relationship_views=relationship_views, path=path + ) class_view_pathname.unlink(missing_ok=True) packages_pathname.unlink(missing_ok=True) - return class_views + return class_views, relationship_views async def _parse_classes(self, class_view_pathname): class_views = [] if not class_view_pathname.exists(): return class_views - async with aiofiles.open(str(class_view_pathname), mode="r") as reader: - lines = await reader.readlines() + data = await aread(filename=class_view_pathname, encoding="utf-8") + lines = data.split("\n") for line in lines: package_name, info = RepoParser._split_class_line(line) if not package_name: @@ -229,6 +273,19 @@ async def _parse_classes(self, class_view_pathname): class_views.append(class_info) return class_views + async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]: + relationship_views = [] + if not class_view_pathname.exists(): + return relationship_views + data = await aread(filename=class_view_pathname, encoding="utf-8") + lines = data.split("\n") + for line in lines: + relationship = RepoParser._split_relationship_line(line) + if not relationship: + continue + relationship_views.append(relationship) + return relationship_views + @staticmethod def _split_class_line(line): part_splitor = '" [' @@ -247,6 +304,40 @@ def _split_class_line(line): info = re.sub(r"]*>", "\n", info) return class_name, info + @staticmethod + def _split_relationship_line(line): + splitters = [" -> ", " [", "];"] + idxs = [] + for tag in splitters: + if tag not in line: + return None + idxs.append(line.find(tag)) + ret = ClassRelationship() + ret.src = line[0 : idxs[0]].strip('"') + ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"') + properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ") + mappings = { + 'arrowhead="empty"': GENERALIZATION, + 'arrowhead="diamond"': COMPOSITION, + 'arrowhead="odiamond"': AGGREGATION, + } + for k, v in mappings.items(): + if k in properties: + ret.relationship = v + if v != GENERALIZATION: + ret.label = RepoParser._get_label(properties) + break + return ret + + @staticmethod + def _get_label(line): + tag = 'label="' + if tag not in line: + return "" + ix = line.find(tag) + eix = line.find('"', ix + len(tag)) + return line[ix + len(tag) : eix] + @staticmethod def _create_path_mapping(path: str | Path) -> Dict[str, str]: mappings = { @@ -271,7 +362,9 @@ def _create_path_mapping(path: str | Path) -> Dict[str, str]: return mappings @staticmethod - def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]: + def _repair_namespaces( + class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path + ) -> (List[ClassInfo], List[ClassRelationship]): if not class_views: return [] c = class_views[0] @@ -290,7 +383,12 @@ def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[C for c in class_views: c.package = RepoParser._repair_ns(c.package, new_mappings) - return class_views + for i in range(len(relationship_views)): + v = relationship_views[i] + v.src = RepoParser._repair_ns(v.src, new_mappings) + v.dest = RepoParser._repair_ns(v.dest, new_mappings) + relationship_views[i] = v + return class_views, relationship_views @staticmethod def _repair_ns(package, mappings): diff --git a/metagpt/schema.py b/metagpt/schema.py index e36bef395..02d44f767 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -451,3 +451,63 @@ def __hash__(self): class BugFixContext(BaseContext): filename: str = "" + + +# mermaid class view +class ClassMeta(BaseModel): + name: str = "" + abstraction: bool = False + static: bool = False + visibility: str = "" + + +class ClassAttribute(ClassMeta): + value_type: str = "" + default_value: str = "" + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + self.visibility + if self.value_type: + content += self.value_type + " " + content += self.name + if self.default_value: + content += "=" + if self.value_type not in ["str", "string", "String"]: + content += self.default_value + else: + content += '"' + self.default_value.replace('"', "") + '"' + if self.abstraction: + content += "*" + if self.static: + content += "$" + return content + + +class ClassMethod(ClassMeta): + args: List[ClassAttribute] = Field(default_factory=list) + return_type: str = "" + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + self.visibility + content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")" + if self.return_type: + content += ":" + self.return_type + if self.abstraction: + content += "*" + if self.static: + content += "$" + return content + + +class ClassView(ClassMeta): + attributes: List[ClassAttribute] = Field(default_factory=list) + methods: List[ClassMethod] = Field(default_factory=list) + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n" + for v in self.attributes: + content += v.get_mermaid(align=align + 1) + "\n" + for v in self.methods: + content += v.get_mermaid(align=align + 1) + "\n" + content += "".join(["\t" for i in range(align)]) + "}\n" + return content diff --git a/metagpt/tools/metagpt_oas3_api_svc.py b/metagpt/tools/metagpt_oas3_api_svc.py index 319e7efb2..8e9f4a0da 100644 --- a/metagpt/tools/metagpt_oas3_api_svc.py +++ b/metagpt/tools/metagpt_oas3_api_svc.py @@ -5,6 +5,12 @@ @Author : mashenquan @File : metagpt_oas3_api_svc.py @Desc : MetaGPT OpenAPI Specification 3.0 REST API service + + curl -X 'POST' \ + 'http://localhost:8080/openapi/greeting/dave' \ + -H 'accept: text/plain' \ + -H 'Content-Type: application/json' \ + -d '{}' """ from pathlib import Path @@ -15,7 +21,7 @@ def oas_http_svc(): """Start the OAS 3.0 OpenAPI HTTP service""" print("http://localhost:8080/oas3/ui/") - specification_dir = Path(__file__).parent.parent.parent / ".well-known" + specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known" app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("metagpt_oas3_api.yaml") app.add_api("openapi.yaml") diff --git a/metagpt/tools/openapi_v3_hello.py b/metagpt/tools/openapi_v3_hello.py index c8f5de42d..d1c83eac2 100644 --- a/metagpt/tools/openapi_v3_hello.py +++ b/metagpt/tools/openapi_v3_hello.py @@ -23,7 +23,7 @@ async def post_greeting(name: str) -> str: if __name__ == "__main__": - specification_dir = Path(__file__).parent.parent.parent / ".well-known" + specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known" app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir)) app.add_api("openapi.yaml", arguments={"title": "Hello World Example"}) app.run(port=8082) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index c7751c2af..0032f0b0d 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -407,6 +407,10 @@ def concat_namespace(*args) -> str: return ":".join(str(value) for value in args) +def split_namespace(ns_class_name: str) -> List[str]: + return ns_class_name.split(":") + + def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: """ Generates a logging function to be used after a call is retried. diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py index 08f4327fa..8bb5f9bb3 100644 --- a/metagpt/utils/di_graph_repository.py +++ b/metagpt/utils/di_graph_repository.py @@ -12,9 +12,9 @@ from pathlib import Path from typing import List -import aiofiles import networkx +from metagpt.utils.common import aread, awrite from metagpt.utils.graph_repository import SPO, GraphRepository @@ -55,12 +55,10 @@ async def save(self, path: str | Path = None): if not path.exists(): path.mkdir(parents=True, exist_ok=True) pathname = Path(path) / self.name - async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer: - await writer.write(data) + await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8") async def load(self, pathname: str | Path): - async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader: - data = await reader.read(-1) + data = await aread(filename=pathname, encoding="utf-8") m = json.loads(data) self._repo = networkx.node_link_graph(m) diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index ff750fbbb..0ddca414d 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -138,6 +138,8 @@ def changed_files(self) -> Dict[str, str]: files = self._git_repo.changed_files relative_files = {} for p, ct in files.items(): + if ct.value == "D": # deleted + continue try: rf = Path(p).relative_to(self._relative_path) except ValueError: diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py index 37da3dee4..88946c98e 100644 --- a/metagpt/utils/graph_repository.py +++ b/metagpt/utils/graph_repository.py @@ -13,19 +13,25 @@ from pydantic import BaseModel -from metagpt.repo_parser import ClassInfo, RepoFileInfo +from metagpt.logs import logger +from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo from metagpt.utils.common import concat_namespace class GraphKeyword: IS = "is" + OF = "Of" + ON = "On" CLASS = "class" FUNCTION = "function" + HAS_FUNCTION = "has_function" SOURCE_CODE = "source_code" NULL = "" GLOBAL_VARIABLE = "global_variable" CLASS_FUNCTION = "class_function" CLASS_PROPERTY = "class_property" + HAS_CLASS_FUNCTION = "has_class_function" + HAS_CLASS_PROPERTY = "has_class_property" HAS_CLASS = "has_class" HAS_PAGE_INFO = "has_page_info" HAS_CLASS_VIEW = "has_class_view" @@ -73,11 +79,13 @@ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type) for c in file_info.classes: class_name = c.get("name", "") + # file -> class await graph_db.insert( subject=file_info.file, predicate=GraphKeyword.HAS_CLASS, object_=concat_namespace(file_info.file, class_name), ) + # class detail await graph_db.insert( subject=concat_namespace(file_info.file, class_name), predicate=GraphKeyword.IS, @@ -85,12 +93,22 @@ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: ) methods = c.get("methods", []) for fn in methods: + await graph_db.insert( + subject=concat_namespace(file_info.file, class_name), + predicate=GraphKeyword.HAS_CLASS_FUNCTION, + object_=concat_namespace(file_info.file, class_name, fn), + ) await graph_db.insert( subject=concat_namespace(file_info.file, class_name, fn), predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS_FUNCTION, ) for f in file_info.functions: + # file -> function + await graph_db.insert( + subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f) + ) + # function detail await graph_db.insert( subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION ) @@ -105,13 +123,13 @@ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: await graph_db.insert( subject=concat_namespace(file_info.file, *code_block.tokens), predicate=GraphKeyword.HAS_PAGE_INFO, - object_=code_block.json(ensure_ascii=False), + object_=code_block.model_dump_json(), ) for k, v in code_block.properties.items(): await graph_db.insert( subject=concat_namespace(file_info.file, k, v), predicate=GraphKeyword.HAS_PAGE_INFO, - object_=code_block.json(ensure_ascii=False), + object_=code_block.model_dump_json(), ) @staticmethod @@ -129,6 +147,13 @@ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_vi object_=GraphKeyword.CLASS, ) for vn, vt in c.attributes.items(): + # class -> property + await graph_db.insert( + subject=c.package, + predicate=GraphKeyword.HAS_CLASS_PROPERTY, + object_=concat_namespace(c.package, vn), + ) + # property detail await graph_db.insert( subject=concat_namespace(c.package, vn), predicate=GraphKeyword.IS, @@ -138,6 +163,15 @@ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_vi subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt ) for fn, desc in c.methods.items(): + if "" in desc and "" not in desc: + logger.error(desc) + # class -> function + await graph_db.insert( + subject=c.package, + predicate=GraphKeyword.HAS_CLASS_FUNCTION, + object_=concat_namespace(c.package, fn), + ) + # function detail await graph_db.insert( subject=concat_namespace(c.package, fn), predicate=GraphKeyword.IS, @@ -148,3 +182,19 @@ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_vi predicate=GraphKeyword.HAS_ARGS_DESC, object_=desc, ) + + @staticmethod + async def update_graph_db_with_class_relationship_views( + graph_db: "GraphRepository", relationship_views: List[ClassRelationship] + ): + for r in relationship_views: + await graph_db.insert( + subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest + ) + if not r.label: + continue + await graph_db.insert( + subject=r.src, + predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON, + object_=concat_namespace(r.dest, r.label), + ) diff --git a/tests/conftest.py b/tests/conftest.py index fbf9ff465..a15e3e85b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ import logging import os import re +import uuid from typing import Optional import pytest @@ -151,9 +152,9 @@ def emit(self, record): # init & dispose git repo -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function", autouse=True) def setup_and_teardown_git_repo(request): - CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest") + CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") CONFIG.git_reinit = True # Destroy git repo at the end of the test session. diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py index 955c6ae3b..0103e9d05 100644 --- a/tests/metagpt/actions/test_rebuild_class_view.py +++ b/tests/metagpt/actions/test_rebuild_class_view.py @@ -11,13 +11,19 @@ import pytest from metagpt.actions.rebuild_class_view import RebuildClassView +from metagpt.config import CONFIG +from metagpt.const import GRAPH_REPO_FILE_REPO from metagpt.llm import LLM @pytest.mark.asyncio async def test_rebuild(): - action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM()) + action = RebuildClassView( + name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM() + ) await action.run() + graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO) + assert graph_file_repo.changed_files if __name__ == "__main__": diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 816c186e2..b6e334fbe 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -19,6 +19,9 @@ from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, + ClassAttribute, + ClassMethod, + ClassView, CodeSummarizeContext, Document, Message, @@ -156,5 +159,30 @@ def test_CodeSummarizeContext(file_list, want): assert want in m +def test_class_view(): + attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True) + assert attr_a.get_mermaid(align=1) == "\t+int a=0*" + attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True) + assert attr_b.get_mermaid(align=0) == '#str b="0"$' + class_view = ClassView(name="A") + class_view.attributes = [attr_a, attr_b] + + method_a = ClassMethod(name="run", visibility="+", abstraction=True) + assert method_a.get_mermaid(align=1) == "\t+run()*" + method_b = ClassMethod( + name="_test", + visibility="#", + static=True, + args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")], + return_type="str", + ) + assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$" + class_view.methods = [method_a, method_b] + assert ( + class_view.get_mermaid(align=0) + == 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n' + ) + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/test_metagpt_oas3_api_svc.py b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py index 1135860eb..5f52b28cc 100644 --- a/tests/metagpt/tools/test_metagpt_oas3_api_svc.py +++ b/tests/metagpt/tools/test_metagpt_oas3_api_svc.py @@ -24,13 +24,14 @@ async def test_oas2_svc(): process = subprocess.Popen(["python", str(script_pathname)], cwd=str(workdir), env=env) await asyncio.sleep(5) - url = "http://localhost:8080/openapi/greeting/dave" - headers = {"accept": "text/plain", "Content-Type": "application/json"} - data = {} - response = requests.post(url, headers=headers, json=data) - assert response.text == "Hello dave\n" - - process.terminate() + try: + url = "http://localhost:8080/openapi/greeting/dave" + headers = {"accept": "text/plain", "Content-Type": "application/json"} + data = {} + response = requests.post(url, headers=headers, json=data) + assert response.text == "Hello dave\n" + finally: + process.terminate() if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_hello.py b/tests/metagpt/tools/test_openapi_v3_hello.py similarity index 65% rename from tests/metagpt/tools/test_hello.py rename to tests/metagpt/tools/test_openapi_v3_hello.py index 7e61532ab..5726cf8e0 100644 --- a/tests/metagpt/tools/test_hello.py +++ b/tests/metagpt/tools/test_openapi_v3_hello.py @@ -3,7 +3,7 @@ """ @Time : 2023/12/26 @Author : mashenquan -@File : test_hello.py +@File : test_openapi_v3_hello.py """ import asyncio import subprocess @@ -24,13 +24,14 @@ async def test_hello(): process = subprocess.Popen(["python", str(script_pathname)], cwd=workdir, env=env) await asyncio.sleep(5) - url = "http://localhost:8082/openapi/greeting/dave" - headers = {"accept": "text/plain", "Content-Type": "application/json"} - data = {} - response = requests.post(url, headers=headers, json=data) - assert response.text == "Hello dave\n" - - process.terminate() + try: + url = "http://localhost:8082/openapi/greeting/dave" + headers = {"accept": "text/plain", "Content-Type": "application/json"} + data = {} + response = requests.post(url, headers=headers, json=data) + assert response.text == "Hello dave\n" + finally: + process.terminate() if __name__ == "__main__": diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 0342a92af..9b1fa878e 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -36,6 +36,7 @@ read_file_block, read_json_file, require_python_version, + split_namespace, ) @@ -163,6 +164,23 @@ def test_concat_namespace(self): assert concat_namespace("a", "b", "c", "e") == "a:b:c:e" assert concat_namespace("a", "b", "c", "e", "f") == "a:b:c:e:f" + @pytest.mark.parametrize( + ("val", "want"), + [ + ( + "tests/metagpt/test_role.py:test_react:Input:subscription", + ["tests/metagpt/test_role.py", "test_react", "Input", "subscription"], + ), + ( + "tests/metagpt/test_role.py:test_react:Input:goal", + ["tests/metagpt/test_role.py", "test_react", "Input", "goal"], + ), + ], + ) + def test_split_namespace(self, val, want): + res = split_namespace(val) + assert res == want + def test_read_json_file(self): assert read_json_file(str(Path(__file__).parent / "../../data/ut_writer/yft_swaggerApi.json"), encoding="utf-8") with pytest.raises(FileNotFoundError): diff --git a/tests/metagpt/utils/test_redis.py b/tests/metagpt/utils/test_redis.py index b93ff0cdb..d499418ac 100644 --- a/tests/metagpt/utils/test_redis.py +++ b/tests/metagpt/utils/test_redis.py @@ -6,20 +6,34 @@ @File : test_redis.py """ +import mock import pytest from metagpt.config import CONFIG from metagpt.utils.redis import Redis +async def async_mock_from_url(*args, **kwargs): + mock_client = mock.AsyncMock() + mock_client.set.return_value = None + mock_client.get.side_effect = [b"test", b""] + return mock_client + + @pytest.mark.asyncio -async def test_redis(): +@mock.patch("aioredis.from_url", return_value=async_mock_from_url()) +async def test_redis(mock_from_url): + # Mock + # mock_client = mock.AsyncMock() + # mock_client.set.return_value=None + # mock_client.get.side_effect = [b'test', b''] + # mock_from_url.return_value = mock_client + # Prerequisites - assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" - assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" - # assert CONFIG.REDIS_USER - assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD" - assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based" + CONFIG.REDIS_HOST = "MOCK_REDIS_HOST" + CONFIG.REDIS_PORT = "MOCK_REDIS_PORT" + CONFIG.REDIS_PASSWORD = "MOCK_REDIS_PASSWORD" + CONFIG.REDIS_DB = 0 conn = Redis() assert not conn.is_valid diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py index f74e7b52a..132aa0635 100644 --- a/tests/metagpt/utils/test_s3.py +++ b/tests/metagpt/utils/test_s3.py @@ -9,20 +9,36 @@ from pathlib import Path import aiofiles +import mock import pytest from metagpt.config import CONFIG +from metagpt.utils.common import aread from metagpt.utils.s3 import S3 @pytest.mark.asyncio -async def test_s3(): +@mock.patch("aioboto3.Session") +async def test_s3(mock_session_class): + # Set up the mock response + data = await aread(__file__, "utf-8") + mock_session_object = mock.Mock() + reader_mock = mock.AsyncMock() + reader_mock.read.side_effect = [data.encode("utf-8"), b"", data.encode("utf-8")] + type(reader_mock).url = mock.PropertyMock(return_value="https://mock") + mock_client = mock.AsyncMock() + mock_client.put_object.return_value = None + mock_client.get_object.return_value = {"Body": reader_mock} + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_session_object.client.return_value = mock_client + mock_session_class.return_value = mock_session_object + # Prerequisites - assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY" - assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY" - assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL" - # assert CONFIG.S3_SECURE: true # true/false - assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET" + # assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY" + # assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY" + # assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL" + # assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET" conn = S3() assert conn.is_valid @@ -42,6 +58,7 @@ async def test_s3(): assert "http" in res # Mock session env + type(reader_mock).url = mock.PropertyMock(return_value="") old_options = CONFIG.options.copy() new_options = old_options.copy() new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY" @@ -54,6 +71,8 @@ async def test_s3(): finally: CONFIG.set_context(old_options) + await reader.close() + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_session.py b/tests/metagpt/utils/test_session.py new file mode 100644 index 000000000..eab2587a2 --- /dev/null +++ b/tests/metagpt/utils/test_session.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +import pytest + + +def test_nodeid(request): + print(request.node.nodeid) + assert request.node.nodeid + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])