From 5316bc455aafefe7c9f26bd4308cc1095f89dd5d Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Tue, 17 Sep 2024 21:47:15 -0400 Subject: [PATCH] add communities to mentions endpoint --- graphiti_core/graphiti.py | 17 +++++++++++++++++ graphiti_core/search/search_utils.py | 23 +++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index ac53ea12..91cb94cb 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -35,6 +35,7 @@ ) from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, + get_communities_by_nodes, get_mentioned_nodes, get_relevant_edges, get_relevant_nodes, @@ -681,3 +682,19 @@ async def get_nodes_by_query( await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid) ).nodes return nodes + + +async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults: + episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) + + edges_list = await asyncio.gather( + *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes] + ) + + edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst] + + nodes = await get_mentioned_nodes(self.driver, episodes) + + communities = await get_communities_by_nodes(self.driver, nodes) + + return SearchResults(edges=edges, nodes=nodes, communities=communities) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 7bb27d2b..0cc19be5 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -59,6 +59,29 @@ async def get_mentioned_nodes( return nodes +async def get_communities_by_nodes( + driver: AsyncDriver, nodes: list[EntityNode] +) -> list[CommunityNode]: + node_uuids = [node.uuid for node in nodes] + records, _, _ = await driver.execute_query( + """ + MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids + RETURN DISTINCT + c.uuid As uuid, + c.group_id AS group_id, + c.name AS name, + c.name_embedding AS name_embedding + c.created_at AS created_at, + c.summary AS summary + """, + uuids=node_uuids, + ) + + communities = [get_community_node_from_record(record) for record in records] + + return communities + + async def edge_fulltext_search( driver: AsyncDriver, query: str,