Skip to content

Commit

Permalink
add communities to mentions endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Sep 18, 2024
1 parent 1644e4b commit 5316bc4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
17 changes: 17 additions & 0 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
Expand Down Expand Up @@ -681,3 +682,19 @@ async def get_nodes_by_query(
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
).nodes
return nodes


async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

edges_list = await asyncio.gather(
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
)

edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

nodes = await get_mentioned_nodes(self.driver, episodes)

communities = await get_communities_by_nodes(self.driver, nodes)

return SearchResults(edges=edges, nodes=nodes, communities=communities)
23 changes: 23 additions & 0 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ async def get_mentioned_nodes(
return nodes


async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
c.group_id AS group_id,
c.name AS name,
c.name_embedding AS name_embedding
c.created_at AS created_at,
c.summary AS summary
""",
uuids=node_uuids,
)

communities = [get_community_node_from_record(record) for record in records]

return communities


async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
Expand Down

0 comments on commit 5316bc4

Please sign in to comment.