Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In memory label propagation community detection #136

Merged
merged 6 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ async def search(
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
):
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.

Expand Down
98 changes: 81 additions & 17 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime

from neo4j import AsyncDriver
from pydantic import BaseModel

from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient
Expand All @@ -14,6 +15,11 @@
logger = logging.getLogger(__name__)


class Neighbor(BaseModel):
    """One adjacency entry in the in-memory community projection."""

    # uuid of the entity node on the other end of the relationship(s)
    node_uuid: str
    # number of edges connecting the two nodes; used as the vote weight
    # when label_propagation tallies neighbouring communities
    edge_count: int


async def build_community_projection(driver: AsyncDriver) -> str:
records, _, _ = await driver.execute_query("""
CALL gds.graph.project("communities", "Entity",
Expand All @@ -38,27 +44,87 @@
)


async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
    """
    Cluster entity nodes into communities, one label-propagation run per group_id.

    For every distinct non-null ``group_id`` found on ``Entity`` nodes, an
    in-memory adjacency projection is built (one query per node, counting the
    ``RELATES_TO`` edges to each neighbour in the same group) and handed to
    ``label_propagation``. The resulting uuid clusters are then resolved back
    into ``EntityNode`` objects.

    Args:
        driver: Neo4j async driver used for all queries.

    Returns:
        A list of clusters across all groups; each cluster is a list of the
        ``EntityNode`` objects belonging to one community.
    """
    community_clusters: list[list[EntityNode]] = []

    group_id_values, _, _ = await driver.execute_query("""
    MATCH (n:Entity WHERE n.group_id IS NOT NULL)
    RETURN
        collect(DISTINCT n.group_id) AS group_ids
    """)

    group_ids = group_id_values[0]['group_ids']
    for group_id in group_ids:
        # node uuid -> neighbours within the same group, weighted by edge count
        projection: dict[str, list[Neighbor]] = {}
        nodes = await EntityNode.get_by_group_ids(driver, [group_id])
        for node in nodes:
            records, _, _ = await driver.execute_query(
                """
            MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
            WITH count(r) AS count, m.uuid AS uuid
            RETURN
                uuid,
                count
            """,
                uuid=node.uuid,
                group_id=group_id,
            )

            projection[node.uuid] = [
                Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
            ]

        cluster_uuids = label_propagation(projection)

        # Accumulate across groups: assigning here instead of extending would
        # discard the clusters of every group except the last one processed.
        community_clusters.extend(
            await asyncio.gather(
                *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
            )
        )

    return community_clusters


def label_propagation(projection: dict[str, list['Neighbor']]) -> list[list[str]]:
    """
    Partition nodes into communities with a synchronous label-propagation pass.

    Every node starts in its own community (numbered by its position in
    ``projection``). On each round a node tallies its neighbours' current
    communities, weighted by ``edge_count``, picks the heaviest one (ties go
    to the higher community id), and adopts it only if that id is larger than
    its own. Rounds repeat until no node changes community.

    Args:
        projection: Maps each node uuid to its neighbours. Every neighbour's
            ``node_uuid`` must itself be a key of ``projection``.

    Returns:
        The communities as lists of node uuids. A node without neighbours
        ends up in a cluster of its own.

    Raises:
        KeyError: If a neighbour references a uuid missing from ``projection``.
    """
    community_map: dict[str, int] = {uuid: i for i, uuid in enumerate(projection)}

    while True:
        no_change = True
        new_community_map: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            curr_community = community_map[uuid]

            # community id -> total edge weight towards that community
            community_candidates: defaultdict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count

            # Heaviest candidate; ties are broken by the larger community id.
            best = max(
                ((count, community) for community, count in community_candidates.items()),
                default=None,
            )

            if best is None:
                # Isolated node: nothing to vote on, so it keeps its label.
                new_community = curr_community
            else:
                # Labels only ever grow, which guarantees the loop terminates.
                new_community = max(best[1], curr_community)

            new_community_map[uuid] = new_community

            if new_community != curr_community:
                no_change = False

        if no_change:
            break

        community_map = new_community_map

    community_cluster_map: defaultdict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)

    return list(community_cluster_map.values())


async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
# Prepare context for LLM
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
Expand Down Expand Up @@ -129,8 +195,7 @@
async def build_communities(
driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
projection = await build_community_projection(driver)
community_clusters = await get_community_clusters(driver, projection)
community_clusters = await get_community_clusters(driver)

communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
await asyncio.gather(
Expand All @@ -144,7 +209,6 @@
community_nodes.append(community[0])
community_edges.extend(community[1])

await destroy_projection(driver, projection)
return community_nodes, community_edges


Expand Down
2 changes: 2 additions & 0 deletions tests/test_graphiti_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import asyncio
import json

Check failure on line 18 in tests/test_graphiti_int.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

tests/test_graphiti_int.py:18:8: F401 `json` imported but unused
import logging
import os
import sys
Expand Down Expand Up @@ -74,6 +75,7 @@
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()

edges = await graphiti.search('tania tetlow', group_ids=['1'])

Expand Down
Loading