diff --git a/examples/distributed_debate/user_proxy_agent.py b/examples/distributed_debate/user_proxy_agent.py index 37f4b6d28..7a5d819eb 100644 --- a/examples/distributed_debate/user_proxy_agent.py +++ b/examples/distributed_debate/user_proxy_agent.py @@ -22,7 +22,7 @@ def reply( # type: ignore[override] self.speak(x) return super().reply(x, required_keys) - def observe(self, x: Union[dict, Sequence[dict]]) -> None: + def observe(self, x: Union[Msg, Sequence[Msg]]) -> None: """ Observe with `self.speak(x)` """ diff --git a/src/agentscope/agents/agent.py b/src/agentscope/agents/agent.py index da7794f10..bfe7e99c4 100644 --- a/src/agentscope/agents/agent.py +++ b/src/agentscope/agents/agent.py @@ -345,11 +345,11 @@ def speak( f"object, got {type(content)} instead.", ) - def observe(self, x: Union[dict, Sequence[dict]]) -> None: + def observe(self, x: Union[Msg, Sequence[Msg]]) -> None: """Observe the input, store it in memory without response to it. Args: - x (`Union[dict, Sequence[dict]]`): + x (`Union[Msg, Sequence[Msg]]`): The input message to be recorded in memory. """ if self.memory: diff --git a/src/agentscope/agents/rpc_agent.py b/src/agentscope/agents/rpc_agent.py index 0b3cf245e..48604a9c3 100644 --- a/src/agentscope/agents/rpc_agent.py +++ b/src/agentscope/agents/rpc_agent.py @@ -118,7 +118,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: x=x, ) - def observe(self, x: Union[dict, Sequence[dict]]) -> None: + def observe(self, x: Union[Msg, Sequence[Msg]]) -> None: if self.client is None: self._launch_server() self.client.call_agent_func( diff --git a/src/agentscope/logging.py b/src/agentscope/logging.py index 0ac5a7c48..a4c4a5f4c 100644 --- a/src/agentscope/logging.py +++ b/src/agentscope/logging.py @@ -7,9 +7,10 @@ from loguru import logger -from agentscope.message import Msg -from agentscope.studio._client import _studio_client -from agentscope.web.gradio.utils import ( +from .utils.tools import _guess_type_by_extension +from .message import Msg +from .studio._client import _studio_client +from .web.gradio.utils import ( generate_image_from_name, send_msg, get_reset_msg, @@ -120,56 +121,44 @@ def log_msg(msg: Msg, disable_gradio: bool = False) -> None: _save_msg(msg) -def log_gradio(message: dict, uid: str, **kwargs: Any) -> None: +def log_gradio(msg: Msg, uid: str, **kwargs: Any) -> None: """Send chat message to studio. Args: - message (`dict`): - The message to be logged. It should have "name"(or "role") and - "content" keys, and the message will be logged as ": - ". + msg (`Msg`): + The message to be logged. uid (`str`): The local value 'uid' of the thread. """ if uid: get_reset_msg(uid=uid) - name = message.get("name", "default") or message.get("role", "default") avatar = kwargs.get("avatar", None) or generate_image_from_name( - message["name"], + msg.name, ) - msg = message["content"] + content = msg.content flushing = True - if "url" in message and message["url"]: - flushing = False - if isinstance(message["url"], str): - message["url"] = [message["url"]] - for i in range(len(message["url"])): - msg += "\n" + f"""""" - if "audio_path" in message and message["audio_path"]: - flushing = False - if isinstance(message["audio_path"], str): - message["audio_path"] = [message["audio_path"]] - for i in range(len(message["audio_path"])): - msg += ( - "\n" - + f"""""" - ) - if "video_path" in message and message["video_path"]: + if msg.url is not None: flushing = False - if isinstance(message["video_path"], str): - message["video_path"] = [message["video_path"]] - for i in range(len(message["video_path"])): - msg += ( - "\n" - + f"""""" - ) + if isinstance(msg.url, str): + urls = [msg.url] + else: + urls = msg.url + + for url in urls: + typ = _guess_type_by_extension(url) + if typ == "image": + content += f"\n" + elif typ == "audio": + content += f"\n" + elif typ == "video": + content += f"\n" + else: + content += f"\n{url}" send_msg( - msg, - role=name, + content, + role=msg.name, uid=uid, flushing=flushing, avatar=avatar, diff --git a/src/agentscope/msghub.py b/src/agentscope/msghub.py index ab8b440a5..68e2049e9 100644 --- a/src/agentscope/msghub.py +++ b/src/agentscope/msghub.py @@ -6,7 +6,8 @@ from loguru import logger -from agentscope.agents import AgentBase +from .agents import AgentBase +from .message import Msg class MsgHubManager: @@ -15,7 +16,7 @@ class MsgHubManager: def __init__( self, participants: Sequence[AgentBase], - announcement: Optional[Union[Sequence[dict], dict]] = None, + announcement: Optional[Union[Sequence[Msg], Msg]] = None, ) -> None: """Initialize a msghub manager from the given arguments. @@ -23,7 +24,7 @@ def __init__( participants (`Sequence[AgentBase]`): The Sequence of participants in the msghub. announcement - (`Optional[Union[list[dict], dict]]`, defaults to `None`): + (`Optional[Union[list[Msg], Msg]]`, defaults to `None`): The message that will be broadcast to all participants at the first without requiring response. """ @@ -102,11 +103,11 @@ def delete( # Remove this agent from the audience of other agents self._reset_audience() - def broadcast(self, msg: Union[dict, list[dict]]) -> None: + def broadcast(self, msg: Union[Msg, Sequence[Msg]]) -> None: """Broadcast the message to all participants. Args: - msg (`Union[dict, list[dict]]`): + msg (`Union[Msg, Sequence[Msg]]`): One or a list of dict messages to broadcast among all participants. """ @@ -116,14 +117,14 @@ def broadcast(self, msg: Union[dict, list[dict]]) -> None: def msghub( participants: Sequence[AgentBase], - announcement: Optional[Union[Sequence[dict], dict]] = None, + announcement: Optional[Union[Sequence[Msg], Msg]] = None, ) -> MsgHubManager: """msghub is used to share messages among a group of agents. Args: participants (`Sequence[AgentBase]`): A Sequence of participated agents in the msghub. - announcement (`Optional[Union[list[dict], dict]]`, defaults to `None`): + announcement (`Optional[Union[list[Msg], Msg]]`, defaults to `None`): The message that will be broadcast to all participants at the very beginning without requiring response. diff --git a/tests/rpc_agent_test.py b/tests/rpc_agent_test.py index bbce7c95d..7f0a4ae8a 100644 --- a/tests/rpc_agent_test.py +++ b/tests/rpc_agent_test.py @@ -153,7 +153,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: class FileAgent(AgentBase): """An agent returns a file""" - def reply(self, x: dict = None) -> dict: + def reply(self, x: Msg = None) -> Msg: image_path = os.path.abspath( os.path.join( os.path.abspath(os.path.dirname(__file__)),