diff --git a/config/config2.yaml.example b/config/config2.yaml.example index bead3c626..8f4a33fc1 100644 --- a/config/config2.yaml.example +++ b/config/config2.yaml.example @@ -11,6 +11,10 @@ search: api_key: "YOUR_API_KEY" cse_id: "YOUR_CSE_ID" +browser: + engine: "playwright" # playwright/selenium + browser_type: "chromium" # playwright: chromium/firefox/webkit; selenium: chrome/firefox/edge/ie + mermaid: engine: "pyppeteer" path: "/Applications/Google Chrome.app" diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 9406a2965..97b1378ee 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -5,17 +5,20 @@ import asyncio from metagpt.roles import Searcher -from metagpt.tools import SearchEngineType +from metagpt.tools.search_engine import SearchEngine, SearchEngineType async def main(): question = "What are the most interesting human facts?" + kwargs = {"api_key": "", "cse_id": "", "proxy": None} # Serper API - # await Searcher(engine=SearchEngineType.SERPER_GOOGLE).run(question) + # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question) # SerpAPI - await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run(question) + # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question) # Google API - # await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run(question) + # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question) + # DDG API + await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question) if __name__ == "__main__": diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2755628c9..2ebeadb66 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -3,15 +3,15 @@ from __future__ import annotations import asyncio -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union -from pydantic import Field, parse_obj_as +from pydantic import TypeAdapter, model_validator from metagpt.actions import Action from metagpt.config2 import config from metagpt.logs import logger from metagpt.tools.search_engine import SearchEngine -from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType +from metagpt.tools.web_browser_engine import WebBrowserEngine from metagpt.utils.common import OutputParser from metagpt.utils.text import generate_prompt_chunk, reduce_message_length @@ -81,10 +81,16 @@ class CollectLinks(Action): name: str = "CollectLinks" i_context: Optional[str] = None desc: str = "Collect links from a search engine." 
- - search_engine: SearchEngine = Field(default_factory=SearchEngine) + search_func: Optional[Any] = None + search_engine: Optional[SearchEngine] = None rank_func: Optional[Callable[[list[str]], None]] = None + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.search_engine is None: + self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy) + return self + async def run( self, topic: str, @@ -107,7 +113,7 @@ async def run( keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text]) try: keywords = OutputParser.extract_struct(keywords, list) - keywords = parse_obj_as(list[str], keywords) + keywords = TypeAdapter(list[str]).validate_python(keywords) except Exception as e: logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}") keywords = [topic] @@ -133,7 +139,7 @@ def gen_msg(): queries = await self._aask(prompt, [system_text]) try: queries = OutputParser.extract_struct(queries, list) - queries = parse_obj_as(list[str], queries) + queries = TypeAdapter(list[str]).validate_python(queries) except Exception as e: logger.exception(f"fail to break down the research question due to {e}") queries = keywords @@ -178,15 +184,17 @@ class WebBrowseAndSummarize(Action): i_context: Optional[str] = None desc: str = "Explore the web and provide summaries of articles and webpages." browse_func: Union[Callable[[list[str]], None], None] = None - web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.web_browser_engine = WebBrowserEngine( - engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT, - run_func=self.browse_func, - ) + web_browser_engine: Optional[WebBrowserEngine] = None + + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.web_browser_engine is None: + self.web_browser_engine = WebBrowserEngine.from_browser_config( + self.config.browser, + browse_func=self.browse_func, + proxy=self.config.proxy, + ) + return self async def run( self, diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 59b35cd58..7eed7381b 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : search_google.py """ -from typing import Any, Optional +from typing import Optional import pydantic from pydantic import model_validator @@ -13,7 +13,6 @@ from metagpt.actions import Action from metagpt.logs import logger from metagpt.schema import Message -from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements @@ -105,21 +104,19 @@ class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - engine: Optional[SearchEngineType] = None - search_func: Optional[Any] = None search_engine: SearchEngine = None result: str = "" @model_validator(mode="after") - def validate_engine_and_run_func(self): - if self.engine is None: - self.engine = self.config.search_engine - try: - search_engine = SearchEngine(engine=self.engine, run_func=self.search_func) - except pydantic.ValidationError: - search_engine = None - - self.search_engine = search_engine + def validate_search_engine(self): + if self.search_engine is None: + try: + config = self.config + search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy) + 
except pydantic.ValidationError: + search_engine = None + + self.search_engine = search_engine return self async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: diff --git a/metagpt/config2.py b/metagpt/config2.py index 21c17f7be..bc6af18c6 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -51,7 +51,7 @@ class Config(CLIParams, YamlModel): proxy: str = "" # Tool Parameters - search: Optional[SearchConfig] = None + search: SearchConfig = SearchConfig() browser: BrowserConfig = BrowserConfig() mermaid: MermaidConfig = MermaidConfig() diff --git a/metagpt/configs/browser_config.py b/metagpt/configs/browser_config.py index 00f918735..2f8024f44 100644 --- a/metagpt/configs/browser_config.py +++ b/metagpt/configs/browser_config.py @@ -15,6 +15,6 @@ class BrowserConfig(YamlModel): """Config for Browser""" engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT - browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome" - driver: Literal["chromium", "firefox", "webkit"] = "chromium" - path: str = "" + browser_type: Literal["chromium", "firefox", "webkit", "chrome", "edge", "ie"] = "chromium" + """If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium, the value + should be one of "chrome", "firefox", "edge", or "ie".""" diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py index a8ae918db..af928b02a 100644 --- a/metagpt/configs/search_config.py +++ b/metagpt/configs/search_config.py @@ -5,6 +5,8 @@ @Author : alexanderwu @File : search_config.py """ +from typing import Callable, Optional + from metagpt.tools import SearchEngineType from metagpt.utils.yaml_model import YamlModel @@ -12,6 +14,7 @@ class SearchConfig(YamlModel): """Config for Search""" - api_key: str - api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO + api_key: str = "" cse_id: str = "" # for google + search_func: Optional[Callable] = None diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py index bdf2d0734..060150f4d 100644 --- a/metagpt/context_mixin.py +++ b/metagpt/context_mixin.py @@ -7,7 +7,7 @@ """ from typing import Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from metagpt.config2 import Config from metagpt.context import Context @@ -17,7 +17,7 @@ class ContextMixin(BaseModel): """Mixin class for context and config""" - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") # Pydantic has bug on _private_attr when using inheritance, so we use private_* instead # - https://github.com/pydantic/pydantic/issues/7142 @@ -32,15 +32,18 @@ class ContextMixin(BaseModel): # Env/Role/Action will use this llm as private llm, or use self.context._llm instance private_llm: Optional[BaseLLM] = Field(default=None, exclude=True) - def __init__( + @model_validator(mode="after") + def validate_extra(self): + self._process_extra(**(self.model_extra or {})) + return self + + def _process_extra( self, context: Optional[Context] = None, config: Optional[Config] = None, llm: Optional[BaseLLM] = None, - **kwargs, ): - """Initialize with config""" - super().__init__(**kwargs) + """Process the extra field""" self.set_context(context) self.set_config(config) self.set_llm(llm) diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index 7929ce7fe..bc449b5cd
100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -8,12 +8,12 @@ from typing import Optional -from pydantic import Field +from pydantic import Field, model_validator from metagpt.actions import SearchAndSummarize, UserRequirement from metagpt.document_store.base_store import BaseStore from metagpt.roles import Role -from metagpt.tools import SearchEngineType +from metagpt.tools.search_engine import SearchEngine class Sales(Role): @@ -29,14 +29,13 @@ class Sales(Role): store: Optional[BaseStore] = Field(default=None, exclude=True) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._set_store(self.store) - - def _set_store(self, store): - if store: - action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) + @model_validator(mode="after") + def validate_store(self): + if self.store: + search_engine = SearchEngine.from_search_func(search_func=self.store.asearch, proxy=self.config.proxy) + action = SearchAndSummarize(search_engine=search_engine, context=self.context) else: - action = SearchAndSummarize() + action = SearchAndSummarize self.set_actions([action]) self._watch([UserRequirement]) + return self diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 19a73a40e..557c5ae95 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -8,7 +8,9 @@ the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ -from pydantic import Field +from typing import Optional + +from pydantic import Field, model_validator from metagpt.actions import SearchAndSummarize from metagpt.actions.action_node import ActionNode @@ -16,7 +18,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.tools import SearchEngineType +from metagpt.tools.search_engine import SearchEngine class Searcher(Role): @@ -28,33 +30,22 @@ class Searcher(Role): profile (str): Role profile. goal (str): Goal of the searcher. constraints (str): Constraints or limitations for the searcher. - engine (SearchEngineType): The type of search engine to use. + search_engine (SearchEngine): The search engine to use. """ name: str = Field(default="Alice") profile: str = Field(default="Smart Assistant") goal: str = "Provide search services for users" constraints: str = "Answer is rich and complete" - engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE - - def __init__(self, **kwargs) -> None: - """ - Initializes the Searcher role with given attributes. Args: - name (str): Name of the searcher. - profile (str): Role profile. - goal (str): Goal of the searcher. - constraints (str): Constraints or limitations for the searcher. - engine (SearchEngineType): The type of search engine to use.
- """ - super().__init__(**kwargs) - self.set_actions([SearchAndSummarize(engine=self.engine)]) - - def set_search_func(self, search_func): - """Sets a custom search function for the searcher.""" - action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) - self.set_actions([action]) + @model_validator(mode="after") + def post_root(self): + if self.search_engine: + self.set_actions([SearchAndSummarize(search_engine=self.search_engine, context=self.context)]) + else: + self.set_actions([SearchAndSummarize]) + return self async def _act_sp(self) -> Message: """Performs the search action in a single process.""" diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index 0d0db9147..1e540bd0e 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -8,14 +8,23 @@ import importlib from typing import Callable, Coroutine, Literal, Optional, Union, overload +from pydantic import BaseModel, ConfigDict, model_validator from semantic_kernel.skill_definition import sk_function +from metagpt.configs.search_config import SearchConfig +from metagpt.logs import logger from metagpt.tools import SearchEngineType class SkSearchEngine: - def __init__(self): - self.search_engine = SearchEngine() + """A search engine class for executing searches. + + Attributes: + search_engine: The search engine instance used for executing searches. + """ + + def __init__(self, **kwargs): + self.search_engine = SearchEngine(**kwargs) @sk_function( description="searches results from Google. Useful when you need to find short " @@ -28,43 +37,85 @@ async def run(self, query: str) -> str: return result -class SearchEngine: - """Class representing a search engine. - - Args: - engine: The search engine type. Defaults to the search engine specified in the config. - run_func: The function to run the search. Defaults to None. +class SearchEngine(BaseModel): + """A model for configuring and executing searches with different search engines. Attributes: - run_func: The function to run the search. - engine: The search engine type. + model_config: Configuration for the model allowing arbitrary types. + engine: The type of search engine to use. + run_func: An optional callable for running the search. If not provided, it will be determined based on the engine. + api_key: An optional API key for the search engine. + proxy: An optional proxy for the search engine requests. 
""" - def __init__( + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE + run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None + api_key: Optional[str] = None + proxy: Optional[str] = None + + @model_validator(mode="after") + def validate_extra(self): + """Validates extra fields provided to the model and updates the run function accordingly.""" + data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True) + if self.model_extra: + data.update(self.model_extra) + self._process_extra(**data) + return self + + def _process_extra( self, - engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE, - run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None, + run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None, **kwargs, ): - if engine == SearchEngineType.SERPAPI_GOOGLE: + """Processes extra configuration and updates the run function based on the search engine type. + + Args: + run_func: An optional callable for running the search. If not provided, it will be determined based on the engine. + """ + if self.engine == SearchEngineType.SERPAPI_GOOGLE: module = "metagpt.tools.search_engine_serpapi" run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run - elif engine == SearchEngineType.SERPER_GOOGLE: + elif self.engine == SearchEngineType.SERPER_GOOGLE: module = "metagpt.tools.search_engine_serper" run_func = importlib.import_module(module).SerperWrapper(**kwargs).run - elif engine == SearchEngineType.DIRECT_GOOGLE: + elif self.engine == SearchEngineType.DIRECT_GOOGLE: module = "metagpt.tools.search_engine_googleapi" run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run - elif engine == SearchEngineType.DUCK_DUCK_GO: + elif self.engine == SearchEngineType.DUCK_DUCK_GO: module = "metagpt.tools.search_engine_ddg" run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run - elif engine == SearchEngineType.CUSTOM_ENGINE: - pass # run_func = run_func + elif self.engine == SearchEngineType.CUSTOM_ENGINE: + run_func = self.run_func else: raise NotImplementedError - self.engine = engine self.run_func = run_func + @classmethod + def from_search_config(cls, config: SearchConfig, **kwargs): + """Creates a SearchEngine instance from a SearchConfig. + + Args: + config: The search configuration to use for creating the SearchEngine instance. + """ + data = config.model_dump(exclude={"api_type", "search_func"}) + if config.search_func is not None: + data["run_func"] = config.search_func + + return cls(engine=config.api_type, **data, **kwargs) + + @classmethod + def from_search_func( + cls, search_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]], **kwargs + ): + """Creates a SearchEngine instance from a custom search function. + + Args: + search_func: A callable that executes the search. + """ + return cls(engine=SearchEngineType.CUSTOM_ENGINE, run_func=search_func, **kwargs) + @overload def run( self, @@ -83,15 +134,29 @@ def run( ) -> list[dict[str, str]]: ... - async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> Union[str, list[dict[str, str]]]: + async def run( + self, + query: str, + max_results: int = 8, + as_string: bool = True, + ignore_errors: bool = False, + ) -> Union[str, list[dict[str, str]]]: """Run a search query. Args: query: The search query. 
max_results: The maximum number of results to return. Defaults to 8. as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True. + ignore_errors: Whether to ignore errors during the search. Defaults to False. Returns: The search results as a string or a list of dictionaries. """ - return await self.run_func(query, max_results=max_results, as_string=as_string) + try: + return await self.run_func(query, max_results=max_results, as_string=as_string) + except Exception as e: + # Handle errors in the API call + logger.exception(f"fail to search {query} for {e}") + if not ignore_errors: + raise e + return "" if as_string else [] diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py index 3d004a4ee..1412f20cf 100644 --- a/metagpt/tools/search_engine_ddg.py +++ b/metagpt/tools/search_engine_ddg.py @@ -5,9 +5,9 @@ import asyncio import json from concurrent import futures -from typing import Literal, overload +from typing import Literal, Optional, overload -from metagpt.config2 import config +from pydantic import BaseModel, ConfigDict try: from duckduckgo_search import DDGS @@ -18,24 +18,16 @@ ) -class DDGAPIWrapper: - """Wrapper around duckduckgo_search API. +class DDGAPIWrapper(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) - To use this module, you should have the `duckduckgo_search` Python package installed. - """ + loop: Optional[asyncio.AbstractEventLoop] = None + executor: Optional[futures.Executor] = None + proxy: Optional[str] = None - def __init__( - self, - *, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, - ): - kwargs = {} - if config.proxy: - kwargs["proxies"] = config.proxy - self.loop = loop - self.executor = executor - self.ddgs = DDGS(**kwargs) + @property + def ddgs(self): + return DDGS(proxies=self.proxy) @overload def run( diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py index 0a8f796cb..66b5ba950 100644 --- a/metagpt/tools/search_engine_googleapi.py +++ b/metagpt/tools/search_engine_googleapi.py @@ -4,19 +4,16 @@ import asyncio import json +import warnings from concurrent import futures from typing import Optional from urllib.parse import urlparse import httplib2 -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from metagpt.config2 import config -from metagpt.logs import logger +from pydantic import BaseModel, ConfigDict, model_validator try: from googleapiclient.discovery import build - from googleapiclient.errors import HttpError except ImportError: raise ImportError( "To use this module, you should have the `google-api-python-client` Python package installed. 
" @@ -27,40 +24,41 @@ class GoogleAPIWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - google_api_key: Optional[str] = Field(default=None, validate_default=True) - google_cse_id: Optional[str] = Field(default=None, validate_default=True) + api_key: str + cse_id: str loop: Optional[asyncio.AbstractEventLoop] = None executor: Optional[futures.Executor] = None + proxy: Optional[str] = None - @field_validator("google_api_key", mode="before") + @model_validator(mode="before") @classmethod - def check_google_api_key(cls, val: str): - val = val or config.search.api_key - if not val: + def validate_google(cls, values: dict) -> dict: + if "google_api_key" in values: + values.setdefault("api_key", values["google_api_key"]) + warnings.warn("`google_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2) + + if "api_key" not in values: raise ValueError( - "To use, make sure you provide the google_api_key when constructing an object. Alternatively, " - "ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain " + "To use google search engine, make sure you provide the `api_key` when constructing an object. You can obtain " "an API key from https://console.cloud.google.com/apis/credentials." ) - return val - @field_validator("google_cse_id", mode="before") - @classmethod - def check_google_cse_id(cls, val: str): - val = val or config.search.cse_id - if not val: + if "google_cse_id" in values: + values.setdefault("cse_id", values["google_cse_id"]) + warnings.warn("`google_cse_id` is deprecated, use `cse_id` instead", DeprecationWarning, stacklevel=2) + + if "cse_id" not in values: raise ValueError( - "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, " - "ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain " - "an API key from https://programmablesearchengine.google.com/controlpanel/create." + "To use google search engine, make sure you provide the `cse_id` when constructing an object. You can obtain " + "the cse_id from https://programmablesearchengine.google.com/controlpanel/create." 
) - return val + return values @property def google_api_client(self): - build_kwargs = {"developerKey": self.google_api_key} - if config.proxy: - parse_result = urlparse(config.proxy) + build_kwargs = {"developerKey": self.api_key} + if self.proxy: + parse_result = urlparse(self.proxy) proxy_type = parse_result.scheme if proxy_type == "https": proxy_type = "http" @@ -96,17 +94,11 @@ async def run( """ loop = self.loop or asyncio.get_event_loop() future = loop.run_in_executor( - self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute + self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.cse_id).execute ) - try: - result = await future - # Extract the search result items from the response - search_results = result.get("items", []) - - except HttpError as e: - # Handle errors in the API call - logger.exception(f"fail to search {query} for {e}") - search_results = [] + result = await future + # Extract the search result items from the response + search_results = result.get("items", []) focus = focus or ["snippet", "link", "title"] details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 8d27d493d..5744b1b62 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -5,18 +5,17 @@ @Author : alexanderwu @File : search_engine_serpapi.py """ +import warnings from typing import Any, Dict, Optional, Tuple import aiohttp -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from metagpt.config2 import config +from pydantic import BaseModel, ConfigDict, Field, model_validator class SerpAPIWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - search_engine: Any = None #: :meta private: + api_key: str params: dict = Field( default_factory=lambda: { "engine": "google", @@ -25,21 +24,22 @@ class SerpAPIWrapper(BaseModel): "hl": "en", } ) - # should add `validate_default=True` to check with default value - serpapi_api_key: Optional[str] = Field(default=None, validate_default=True) aiosession: Optional[aiohttp.ClientSession] = None + proxy: Optional[str] = None - @field_validator("serpapi_api_key", mode="before") + @model_validator(mode="before") @classmethod - def check_serpapi_api_key(cls, val: str): - val = val or config.search.api_key - if not val: + def validate_serpapi(cls, values: dict) -> dict: + if "serpapi_api_key" in values: + values.setdefault("api_key", values["serpapi_api_key"]) + warnings.warn("`serpapi_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2) + + if "api_key" not in values: raise ValueError( - "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, " - "ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain " - "an API key from https://serpapi.com/." + "To use serpapi search engine, make sure you provide the `api_key` when constructing an object. You can obtain" + " an API key from https://serpapi.com/." 
) - return val + return values async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" @@ -60,11 +60,11 @@ def construct_url_and_params() -> Tuple[str, Dict[str, str]]: url, params = construct_url_and_params() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.get(url, params=params) as response: + async with session.get(url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get(url, params=params) as response: + async with self.aiosession.get(url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() @@ -73,7 +73,7 @@ def construct_url_and_params() -> Tuple[str, Dict[str, str]]: def get_params(self, query: str) -> Dict[str, str]: """Get parameters for SerpAPI.""" _params = { - "api_key": self.serpapi_api_key, + "api_key": self.api_key, "q": query, } params = {**self.params, **_params} diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 71ee2f4f9..ba2fb4f93 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -6,33 +6,34 @@ @File : search_engine_serpapi.py """ import json +import warnings from typing import Any, Dict, Optional, Tuple import aiohttp -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from metagpt.config2 import config +from pydantic import BaseModel, ConfigDict, Field, model_validator class SerperWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - search_engine: Any = None #: :meta private: + api_key: str payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10}) - serper_api_key: Optional[str] = Field(default=None, validate_default=True) aiosession: Optional[aiohttp.ClientSession] = None + proxy: Optional[str] = None - @field_validator("serper_api_key", mode="before") + @model_validator(mode="before") @classmethod - def check_serper_api_key(cls, val: str): - val = val or config.search.api_key - if not val: + def validate_serper(cls, values: dict) -> dict: + if "serper_api_key" in values: + values.setdefault("api_key", values["serper_api_key"]) + warnings.warn("`serper_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2) + + if "api_key" not in values: raise ValueError( - "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, " - "ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain " + "To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain " "an API key from https://serper.dev/." 
) - return val + return values async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through Serper and parse result async.""" @@ -54,11 +55,11 @@ def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: url, payloads, headers = construct_url_and_payload_and_headers() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.post(url, data=payloads, headers=headers) as response: + async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get.post(url, data=payloads, headers=headers) as response: + async with self.aiosession.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() @@ -76,7 +77,7 @@ def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]: return json.dumps(payloads, sort_keys=True) def get_headers(self) -> Dict[str, str]: - headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"} + headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} return headers @staticmethod diff --git a/metagpt/tools/web_browser_engine.py b/metagpt/tools/web_browser_engine.py index 411c1604b..01339e51a 100644 --- a/metagpt/tools/web_browser_engine.py +++ b/metagpt/tools/web_browser_engine.py @@ -1,36 +1,95 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - from __future__ import annotations import importlib -from typing import Any, Callable, Coroutine, overload +from typing import Any, Callable, Coroutine, Optional, Union, overload + +from pydantic import BaseModel, ConfigDict, model_validator +from metagpt.configs.browser_config import BrowserConfig from metagpt.tools import WebBrowserEngineType from metagpt.utils.parse_html import WebPage -class WebBrowserEngine: - def __init__( - self, - engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT, - run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None, - ): - if engine is None: - raise NotImplementedError +class WebBrowserEngine(BaseModel): + """Defines a web browser engine configuration for automated browsing and data extraction. + + This class encapsulates the configuration and operational logic for different web browser engines, + such as Playwright, Selenium, or custom implementations. It provides a unified interface to run + browser automation tasks. + + Attributes: + model_config: Configuration dictionary allowing arbitrary types and extra fields. + engine: The type of web browser engine to use. + run_func: An optional coroutine function to run the browser engine. + proxy: An optional proxy server URL to use with the browser engine. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT + run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None + proxy: Optional[str] = None + + @model_validator(mode="after") + def validate_extra(self): + """Validates and processes extra configuration data after model initialization. + + This method is automatically called by Pydantic to validate and process any extra configuration + data provided to the model. It ensures that the extra data is properly integrated into the model's + configuration and operational logic.
+ + Returns: + The instance itself after processing the extra data. + """ + data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True) + if self.model_extra: + data.update(self.model_extra) + self._process_extra(**data) + return self - if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT: + def _process_extra(self, **kwargs): + """Processes extra configuration data to set up the browser engine run function. + + Depending on the specified engine type, this method dynamically imports and configures + the appropriate browser engine wrapper and its run function. + + Args: + **kwargs: Arbitrary keyword arguments representing extra configuration data. + + Raises: + NotImplementedError: If the engine type is not supported. + """ + if self.engine is WebBrowserEngineType.PLAYWRIGHT: module = "metagpt.tools.web_browser_engine_playwright" - run_func = importlib.import_module(module).PlaywrightWrapper().run - elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM: + run_func = importlib.import_module(module).PlaywrightWrapper(**kwargs).run + elif self.engine is WebBrowserEngineType.SELENIUM: module = "metagpt.tools.web_browser_engine_selenium" - run_func = importlib.import_module(module).SeleniumWrapper().run - elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM: - run_func = run_func + run_func = importlib.import_module(module).SeleniumWrapper(**kwargs).run + elif self.engine is WebBrowserEngineType.CUSTOM: + run_func = self.run_func else: raise NotImplementedError self.run_func = run_func - self.engine = engine + + @classmethod + def from_browser_config(cls, config: BrowserConfig, **kwargs): + """Creates a WebBrowserEngine instance from a BrowserConfig object and additional keyword arguments. + + This class method facilitates the creation of a WebBrowserEngine instance by extracting + configuration data from a BrowserConfig object and optionally merging it with additional + keyword arguments. + + Args: + config: A BrowserConfig object containing base configuration data. + **kwargs: Optional additional keyword arguments to override or extend the configuration. + + Returns: + A new instance of WebBrowserEngine configured according to the provided arguments. + """ + data = config.model_dump() + return cls(**data, **kwargs) @overload async def run(self, url: str) -> WebPage: @@ -41,4 +100,16 @@ async def run(self, url: str, *urls: str) -> list[WebPage]: ... async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + """Runs the browser engine to load one or more web pages. + + This method is the implementation of the overloaded run signatures. It delegates the task + of loading web pages to the configured run function, handling either a single URL or multiple URLs. + + Args: + url: The URL of the first web page to load. + *urls: Additional URLs of web pages to load, if any. + + Returns: + A WebPage object if a single URL is provided, or a list of WebPage objects if multiple URLs are provided. 
+ """ return await self.run_func(url, *urls) diff --git a/metagpt/tools/web_browser_engine_playwright.py b/metagpt/tools/web_browser_engine_playwright.py index f8dabd5ac..2df288b1a 100644 --- a/metagpt/tools/web_browser_engine_playwright.py +++ b/metagpt/tools/web_browser_engine_playwright.py @@ -6,16 +6,16 @@ import asyncio import sys from pathlib import Path -from typing import Literal +from typing import Literal, Optional from playwright.async_api import async_playwright +from pydantic import BaseModel, Field, PrivateAttr -from metagpt.config2 import config from metagpt.logs import logger from metagpt.utils.parse_html import WebPage -class PlaywrightWrapper: +class PlaywrightWrapper(BaseModel): """Wrapper around Playwright. To use this module, you should have the `playwright` Python package installed and ensure that @@ -24,24 +24,23 @@ class PlaywrightWrapper: command `playwright install` for the first time. """ - def __init__( - self, - browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium", - launch_kwargs: dict | None = None, - **kwargs, - ) -> None: - self.browser_type = browser_type - launch_kwargs = launch_kwargs or {} - if config.proxy and "proxy" not in launch_kwargs: + browser_type: Literal["chromium", "firefox", "webkit"] = "chromium" + launch_kwargs: dict = Field(default_factory=dict) + proxy: Optional[str] = None + context_kwargs: dict = Field(default_factory=dict) + _has_run_precheck: bool = PrivateAttr(False) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + launch_kwargs = self.launch_kwargs + if self.proxy and "proxy" not in launch_kwargs: args = launch_kwargs.get("args", []) if not any(str.startswith(i, "--proxy-server=") for i in args): - launch_kwargs["proxy"] = {"server": config.proxy} - self.launch_kwargs = launch_kwargs - context_kwargs = {} + launch_kwargs["proxy"] = {"server": self.proxy} + if "ignore_https_errors" in kwargs: - context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"] - self._context_kwargs = context_kwargs - self._has_run_precheck = False + self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"] async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: async with async_playwright() as ap: @@ -55,7 +54,7 @@ async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: return await _scrape(browser, url) async def _scrape(self, browser, url): - context = await browser.new_context(**self._context_kwargs) + context = await browser.new_context(**self.context_kwargs) page = await context.new_page() async with page: try: @@ -75,8 +74,8 @@ async def _run_precheck(self, browser_type): executable_path = Path(browser_type.executable_path) if not executable_path.exists() and "executable_path" not in self.launch_kwargs: kwargs = {} - if config.proxy: - kwargs["env"] = {"ALL_PROXY": config.proxy} + if self.proxy: + kwargs["env"] = {"ALL_PROXY": self.proxy} await _install_browsers(self.browser_type, **kwargs) if self._has_run_precheck: diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 02dd5c173..3b1682291 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -7,19 +7,19 @@ import importlib from concurrent import futures from copy import deepcopy -from typing import Literal +from typing import Callable, Literal, Optional +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from selenium.webdriver.common.by import By from selenium.webdriver.support 
import expected_conditions as EC from selenium.webdriver.support.wait import WebDriverWait from webdriver_manager.core.download_manager import WDMDownloadManager from webdriver_manager.core.http import WDMHttpClient -from metagpt.config2 import config from metagpt.utils.parse_html import WebPage -class SeleniumWrapper: +class SeleniumWrapper(BaseModel): """Wrapper around Selenium. To use this module, you should check the following: @@ -31,25 +31,28 @@ class SeleniumWrapper: can scrape web pages using the Selenium WebBrowserEngine. """ - def __init__( - self, - browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome", - launch_kwargs: dict | None = None, - *, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, - ) -> None: - self.browser_type = browser_type - launch_kwargs = launch_kwargs or {} - if config.proxy and "proxy-server" not in launch_kwargs: - launch_kwargs["proxy-server"] = config.proxy - - self.executable_path = launch_kwargs.pop("executable_path", None) - self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()] - self._has_run_precheck = False - self._get_driver = None - self.loop = loop - self.executor = executor + model_config = ConfigDict(arbitrary_types_allowed=True) + + browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome" + launch_kwargs: dict = Field(default_factory=dict) + proxy: Optional[str] = None + loop: Optional[asyncio.AbstractEventLoop] = None + executor: Optional[futures.Executor] = None + _has_run_precheck: bool = PrivateAttr(False) + _get_driver: Optional[Callable] = PrivateAttr(None) + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + if self.proxy and "proxy-server" not in self.launch_kwargs: + self.launch_kwargs["proxy-server"] = self.proxy + + @property + def launch_args(self): + return [f"--{k}={v}" for k, v in self.launch_kwargs.items() if k != "executable_path"] + + @property + def executable_path(self): + return self.launch_kwargs.get("executable_path") async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: await self._run_precheck() @@ -66,7 +69,9 @@ async def _run_precheck(self): self.loop = self.loop or asyncio.get_event_loop() self._get_driver = await self.loop.run_in_executor( self.executor, - lambda: _gen_get_driver_func(self.browser_type, *self.launch_args, executable_path=self.executable_path), + lambda: _gen_get_driver_func( + self.browser_type, *self.launch_args, executable_path=self.executable_path, proxy=self.proxy + ), ) self._has_run_precheck = True @@ -92,13 +97,17 @@ def _scrape_website(self, url): class WDMHttpProxyClient(WDMHttpClient): + def __init__(self, proxy: str = None): + super().__init__() + self.proxy = proxy + def get(self, url, **kwargs): - if "proxies" not in kwargs and config.proxy: - kwargs["proxies"] = {"all_proxy": config.proxy} + if "proxies" not in kwargs and self.proxy: + kwargs["proxies"] = {"all_proxy": self.proxy} return super().get(url, **kwargs) -def _gen_get_driver_func(browser_type, *args, executable_path=None): +def _gen_get_driver_func(browser_type, *args, executable_path=None, proxy=None): WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver") Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service") Options = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.options"), "Options") @@ -106,7 +115,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): if not 
executable_path: module_name, type_name = _webdriver_manager_types[browser_type] DriverManager = getattr(importlib.import_module(module_name), type_name) - driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient())) + driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient(proxy=proxy))) # driver_manager.driver_cache.find_driver(driver_manager.driver)) executable_path = driver_manager.install() diff --git a/tests/metagpt/actions/test_research.py b/tests/metagpt/actions/test_research.py index 372a1e876..ed83ce58c 100644 --- a/tests/metagpt/actions/test_research.py +++ b/tests/metagpt/actions/test_research.py @@ -28,9 +28,9 @@ async def mock_llm_ask(self, prompt: str, system_msgs): return "[1,2]" mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask) - resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), context=context).run( - "The application of MetaGPT" - ) + resp = await research.CollectLinks( + search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO), context=context + ).run("The application of MetaGPT") for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]: assert i in resp @@ -50,7 +50,9 @@ def rank_func(results): mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask) resp = await research.CollectLinks( - search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func, context=context + search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO), + rank_func=rank_func, + context=context, ).run("The application of MetaGPT") for x, y, z in zip(rank_before, rank_after, resp.values()): assert x[::-1] == y diff --git a/tests/metagpt/learn/test_google_search.py b/tests/metagpt/learn/test_google_search.py index 7fda6436a..70a146878 100644 --- a/tests/metagpt/learn/test_google_search.py +++ b/tests/metagpt/learn/test_google_search.py @@ -16,6 +16,6 @@ class Input(BaseModel): result = await google_search( seed.input, engine=SearchEngineType.SERPER_GOOGLE, - serper_api_key="mock-serper-key", + api_key="mock-serper-key", ) assert result != "" diff --git a/tests/metagpt/roles/test_researcher.py b/tests/metagpt/roles/test_researcher.py index af81777ac..ba05e1296 100644 --- a/tests/metagpt/roles/test_researcher.py +++ b/tests/metagpt/roles/test_researcher.py @@ -36,7 +36,7 @@ async def test_researcher(mocker, search_engine_mocker, context): role = researcher.Researcher(context=context) for i in role.actions: if isinstance(i, CollectLinks): - i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO) + i.search_engine = SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO) await role.run(topic) assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report") diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 966f53a38..a1f03ef7b 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -12,6 +12,7 @@ import pytest from metagpt.config2 import config +from metagpt.configs.search_config import SearchConfig from metagpt.logs import logger from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -49,27 +50,34 @@ async def test_search_engine( search_engine_mocker, ): # Prerequisites - search_engine_config = {} + search_engine_config = {"engine": search_engine_type, "run_func": run_func} 
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE: assert config.search - search_engine_config["serpapi_api_key"] = "mock-serpapi-key" + search_engine_config["api_key"] = "mock-serpapi-key" elif search_engine_type is SearchEngineType.DIRECT_GOOGLE: assert config.search - search_engine_config["google_api_key"] = "mock-google-key" - search_engine_config["google_cse_id"] = "mock-google-cse" + search_engine_config["api_key"] = "mock-google-key" + search_engine_config["cse_id"] = "mock-google-cse" elif search_engine_type is SearchEngineType.SERPER_GOOGLE: assert config.search - search_engine_config["serper_api_key"] = "mock-serper-key" + search_engine_config["api_key"] = "mock-serper-key" - search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config) - rsp = await search_engine.run("metagpt", max_results, as_string) - logger.info(rsp) - if as_string: - assert isinstance(rsp, str) - else: - assert isinstance(rsp, list) - assert len(rsp) <= max_results + async def test(search_engine): + rsp = await search_engine.run("metagpt", max_results, as_string) + logger.info(rsp) + if as_string: + assert isinstance(rsp, str) + else: + assert isinstance(rsp, list) + assert len(rsp) <= max_results + + await test(SearchEngine(**search_engine_config)) + search_engine_config["api_type"] = search_engine_config.pop("engine") + if run_func: + await test(SearchEngine.from_search_func(run_func)) + search_engine_config["search_func"] = search_engine_config.pop("run_func") + await test(SearchEngine.from_search_config(SearchConfig(**search_engine_config))) if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py index 0e838a2f8..f35848cf4 100644 --- a/tests/metagpt/tools/test_web_browser_engine_playwright.py +++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py @@ -3,7 +3,6 @@ import pytest -from metagpt.config2 import config from metagpt.tools import web_browser_engine_playwright from metagpt.utils.parse_html import WebPage @@ -19,26 +18,22 @@ ids=["chromium-normal", "firefox-normal", "webkit-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd): - global_proxy = config.proxy - try: - if use_proxy: - server, proxy_url = await proxy() - config.proxy = proxy_url - browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs) - result = await browser.run(url) - assert isinstance(result, WebPage) - assert "MetaGPT" in result.inner_text + proxy_url = None + if use_proxy: + server, proxy_url = await proxy() + browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs) + result = await browser.run(url) + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text - if urls: - results = await browser.run(url, *urls) - assert isinstance(results, list) - assert len(results) == len(urls) + 1 - assert all(("MetaGPT" in i.inner_text) for i in results) - if use_proxy: - server.close() - assert "Proxy:" in capfd.readouterr().out - finally: - config.proxy = global_proxy + if urls: + results = await browser.run(url, *urls) + assert isinstance(results, list) + assert len(results) == len(urls) + 1 + assert all(("MetaGPT" in i.inner_text) for i in results) + if use_proxy: + server.close() + assert "Proxy:" in capfd.readouterr().out if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py 
b/tests/metagpt/tools/test_web_browser_engine_selenium.py index 1b1439d29..a88a5d0f4 100644 --- a/tests/metagpt/tools/test_web_browser_engine_selenium.py +++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py @@ -4,7 +4,6 @@ import browsers import pytest -from metagpt.config2 import config from metagpt.tools import web_browser_engine_selenium from metagpt.utils.parse_html import WebPage @@ -40,27 +39,22 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd): # Prerequisites # firefox, chrome, Microsoft Edge - - global_proxy = config.proxy - try: - if use_proxy: - server, proxy_url = await proxy() - config.proxy = proxy_url - browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type) - result = await browser.run(url) - assert isinstance(result, WebPage) - assert "MetaGPT" in result.inner_text - - if urls: - results = await browser.run(url, *urls) - assert isinstance(results, list) - assert len(results) == len(urls) + 1 - assert all(("MetaGPT" in i.inner_text) for i in results) - if use_proxy: - server.close() - assert "Proxy:" in capfd.readouterr().out - finally: - config.proxy = global_proxy + proxy_url = None + if use_proxy: + server, proxy_url = await proxy() + browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url) + result = await browser.run(url) + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text + + if urls: + results = await browser.run(url, *urls) + assert isinstance(results, list) + assert len(results) == len(urls) + 1 + assert all(("MetaGPT" in i.inner_text) for i in results) + if use_proxy: + server.close() + assert "Proxy:" in capfd.readouterr().out if __name__ == "__main__": diff --git a/tests/mock/mock_aiohttp.py b/tests/mock/mock_aiohttp.py index 49dcdba79..a7c022a4b 100644 --- a/tests/mock/mock_aiohttp.py +++ b/tests/mock/mock_aiohttp.py @@ -13,7 +13,8 @@ class MockAioResponse: def __init__(self, session, method, url, **kwargs) -> None: fn = self.check_funcs.get((method, url)) - self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}" + _kwargs = {k: v for k, v in kwargs.items() if k != "proxy"} + self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(_kwargs, sort_keys=True)}" self.mng = self.response = None if self.key not in self.rsp_cache: self.mng = origin_request(session, method, url, **kwargs)
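Reviewer note (not part of the patch): a minimal usage sketch of the refactored `SearchEngine` API, assuming the new defaults above (`SearchEngineType.DUCK_DUCK_GO` needs no API key). `fake_search` and its canned results are hypothetical stand-ins for any coroutine matching the `run_func` signature.

```python
import asyncio

from metagpt.config2 import config
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine


async def fake_search(query: str, max_results: int = 8, as_string: bool = True):
    """Hypothetical custom search coroutine matching the run_func signature."""
    return f"results for {query}" if as_string else [{"title": query, "link": "https://example.com"}]


async def main():
    # 1. Construct directly; `engine` is now a keyword argument (see the test updates above).
    ddg = SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO)
    print(await ddg.run("metagpt", max_results=3))

    # 2. Construct from the global config; the proxy is threaded through explicitly
    #    instead of each wrapper reading it from the global `config` object.
    from_config = SearchEngine.from_search_config(config.search, proxy=config.proxy)
    print(await from_config.run("metagpt", as_string=False, ignore_errors=True))

    # 3. Wrap a custom coroutine; replaces the old CUSTOM_ENGINE + run_func pattern.
    custom = SearchEngine.from_search_func(fake_search)
    print(await custom.run("metagpt"))


if __name__ == "__main__":
    asyncio.run(main())
```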
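The wrapper classes also keep back-compat shims for their old keyword names (`serpapi_api_key`, `serper_api_key`, `google_api_key`, `google_cse_id`). A sketch of that path, with a hypothetical key value:

```python
import warnings

from metagpt.tools.search_engine_serpapi import SerpAPIWrapper

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    wrapper = SerpAPIWrapper(serpapi_api_key="mock-key")  # old spelling, hypothetical key
assert wrapper.api_key == "mock-key"  # value forwarded to the new `api_key` field
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```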
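The browser side mirrors this: a sketch of building the engine via `WebBrowserEngine.from_browser_config`, assuming the default Playwright/chromium config block; the URL is illustrative.

```python
import asyncio

from metagpt.config2 import config
from metagpt.tools.web_browser_engine import WebBrowserEngine


async def main():
    # `engine` and `browser_type` come from the new `browser` config block; the proxy
    # is again passed explicitly rather than read from the global config in the wrappers.
    browser = WebBrowserEngine.from_browser_config(config.browser, proxy=config.proxy)
    page = await browser.run("https://example.com")  # illustrative URL
    print(page.inner_text[:200])  # WebPage.inner_text, as asserted in the tests above


if __name__ == "__main__":
    asyncio.run(main())
```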