Skip to content

Commit

Permalink
add subscription deduplication (#594)
Browse files Browse the repository at this point in the history
* add subscription deduplication

* format

---------

Co-authored-by: Mohammad Mazraeh <mmazraeh@microsoft.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
  • Loading branch information
3 people authored Sep 23, 2024
1 parent 58ee8b7 commit 1ac5272
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext
from autogen_core.components import DefaultTopicId, RoutedAgent, message_handler
from autogen_core.components._default_subscription import DefaultSubscription
from autogen_core.components import DefaultSubscription
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self) -> None:

async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub.id == subscription.id for sub in self._subscriptions):
if any(sub == subscription for sub in self._subscriptions):
raise ValueError("Subscription already exists")

self._subscriptions.append(subscription)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,11 @@ def map_to_agent(self, topic_id: TopicId) -> AgentId:

return AgentId(type=self._agent_type, key=topic_id.source)

def __eq__(self, other: object) -> bool:
if not isinstance(other, TypeSubscription):
return False

return self.id == other.id or (self.agent_type == other.agent_type and self.topic_type == other.topic_type)


BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
2 changes: 1 addition & 1 deletion python/packages/autogen-core/tests/test_closure_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentRuntime, MessageContext
from autogen_core.components import ClosureAgent
from autogen_core.components._default_subscription import DefaultSubscription
from autogen_core.components import DefaultSubscription
from autogen_core.components._default_topic import DefaultTopicId


Expand Down
20 changes: 20 additions & 0 deletions python/packages/autogen-core/tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from autogen_core.base import AgentId, TopicId
from autogen_core.base.exceptions import CantHandleException
from autogen_core.components import DefaultTopicId, TypeSubscription
from autogen_core.components import DefaultSubscription
from test_utils import LoopbackAgent, MessageType


Expand Down Expand Up @@ -96,3 +97,22 @@ async def test_skipped_class_subscriptions() -> None:
AgentId("MyAgent", key="default"), type=LoopbackAgent
)
assert agent_instance.num_calls == 0


@pytest.mark.asyncio
async def test_subscription_deduplication() -> None:
runtime = SingleThreadedAgentRuntime()
agent_type = "MyAgent"

# Test TypeSubscription
type_subscription_1 = TypeSubscription("default", agent_type)
type_subscription_2 = TypeSubscription("default", agent_type)

await runtime.add_subscription(type_subscription_1)
with pytest.raises(ValueError, match="Subscription already exists"):
await runtime.add_subscription(type_subscription_2)

# Test DefaultSubscription
default_subscription = DefaultSubscription(agent_type=agent_type)
with pytest.raises(ValueError, match="Subscription already exists"):
await runtime.add_subscription(default_subscription)

0 comments on commit 1ac5272

Please sign in to comment.