Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In memory label propagation community detection #136

Merged
merged 6 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

<img width="350" alt="Graphiti-ts-small" src="https://github.com/user-attachments/assets/bbd02947-e435-4a05-b25a-bbbac36d52c8">


## Temporal Knowledge Graphs for Agentic Applications

<br />
Expand Down Expand Up @@ -80,7 +79,6 @@ Requirements:

- Python 3.10 or higher
- Neo4j 5.21 or higher
- Neo4j GraphDataScience Plugin (required for community flows)
- OpenAI API key (for LLM inference and embedding)
prasmussen15 marked this conversation as resolved.
Show resolved Hide resolved

Optional:
Expand Down
20 changes: 10 additions & 10 deletions examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()

if not use_bulk:
for i, message in enumerate(messages[3:14]):
for i, message in enumerate(messages[3:4]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
Expand All @@ -76,15 +76,15 @@ async def main(use_bulk: bool = True):
await client.build_communities()

# add additional messages to update communities
for i, message in enumerate(messages[14:20]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
update_communities=True,
)
# for i, message in enumerate(messages[14:20]):
# await client.add_episode(
# name=f'Message {i}',
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
# reference_time=message.actual_timestamp,
# source_description='Podcast Transcript',
# group_id='1',
# update_communities=True,
# )

return

Expand Down
2 changes: 1 addition & 1 deletion graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ async def search(
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
):
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.

Expand Down
124 changes: 94 additions & 30 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime

from neo4j import AsyncDriver
from pydantic import BaseModel

from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient
Expand All @@ -14,6 +15,11 @@
logger = logging.getLogger(__name__)


class Neighbor(BaseModel):
    # One adjacency entry in the in-memory projection used for community
    # detection: a node reachable from the projection key, with a weight.

    # uuid of the adjacent Entity node (filled from `m.uuid` in the
    # neighbour query of get_community_clusters).
    node_uuid: str
    # Number of RELATES_TO relationships to that neighbour (filled from
    # `count(r)`); label_propagation sums it as the neighbour's vote weight.
    edge_count: int


async def build_community_projection(driver: AsyncDriver) -> str:
records, _, _ = await driver.execute_query("""
CALL gds.graph.project("communities", "Entity",
Expand All @@ -29,36 +35,96 @@ async def build_community_projection(driver: AsyncDriver) -> str:
return records[0]['graph']


async def destroy_projection(driver: AsyncDriver, projection_name: str):
    """Drop the graph projection called ``projection_name``.

    Issues a single ``gds.graph.drop`` call through ``driver`` and discards
    whatever the query returns.
    """
    drop_query = '\nCALL gds.graph.drop($projection_name)\n'
    await driver.execute_query(drop_query, projection_name=projection_name)

async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
community_clusters: list[list[EntityNode]] = []

async def get_community_clusters(
driver: AsyncDriver, projection_name: str
) -> list[list[EntityNode]]:
records, _, _ = await driver.execute_query("""
CALL gds.leiden.stream("communities")
YIELD nodeId, communityId
RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
group_id_values, _, _ = await driver.execute_query("""
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
RETURN
collect(DISTINCT n.group_id) AS group_ids
""")
community_map: dict[int, list[str]] = defaultdict(list)
for record in records:
community_map[record['communityId']].append(record['entity_uuid'])

community_clusters: list[list[EntityNode]] = list(
await asyncio.gather(
*[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
group_ids = group_id_values[0]['group_ids']
for group_id in group_ids:
projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
for node in nodes:
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
WITH count(r) AS count, m.uuid AS uuid
RETURN
uuid,
count
""",
uuid=node.uuid,
group_id=group_id,
)

projection[node.uuid] = [
Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
]

cluster_uuids = label_propagation(projection)

community_clusters.extend(
list(
await asyncio.gather(
*[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
)
)
)
)

return community_clusters


def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
    """Cluster node uuids with a synchronous label propagation pass.

    Algorithm:
      1. Every node starts in its own community, numbered by its position in
         ``projection``.
      2. In each round every node looks at the communities of its neighbours
         (as they were at the start of the round), weighting each neighbour by
         its ``edge_count``, and picks the community with the highest total
         weight. Ties are broken in favour of the highest community id.
      3. A node only moves if the picked community id is greater than its
         current one, so ids never decrease and the loop must terminate.
      4. Rounds repeat until a full round changes nothing.

    Args:
        projection: Maps each node uuid to its neighbours. Neighbours whose
            ``node_uuid`` is not itself a key of ``projection`` are ignored
            rather than raising ``KeyError``.

    Returns:
        One list of node uuids per community. Communities and their members
        appear in the iteration order of ``projection``.
    """
    community_map: dict[str, int] = {uuid: label for label, uuid in enumerate(projection)}

    while True:
        new_community_map: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            curr_community = community_map[uuid]

            # Total edge weight contributed by each neighbouring community.
            community_candidates: dict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                neighbor_community = community_map.get(neighbor.node_uuid)
                if neighbor_community is None:
                    # Neighbour lies outside the projection: it has no label to vote with.
                    continue
                community_candidates[neighbor_community] += neighbor.edge_count

            if community_candidates:
                # Comparing (weight, id) tuples picks the heaviest community and
                # resolves equal weights towards the highest community id.
                _, community_candidate = max(
                    (count, community) for community, count in community_candidates.items()
                )
            else:
                community_candidate = -1  # isolated node: always keeps its own label

            # Only ever move to a higher id; this is what guarantees convergence.
            new_community_map[uuid] = max(community_candidate, curr_community)

        if new_community_map == community_map:
            break

        community_map = new_community_map

    community_cluster_map: dict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)

    return list(community_cluster_map.values())


async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
# Prepare context for LLM
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
Expand All @@ -85,7 +151,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s


async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode]
llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
length = len(summaries)
Expand All @@ -99,7 +165,7 @@ async def build_community(
*[
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip(
summaries[: int(length / 2)], summaries[int(length / 2) :]
summaries[: int(length / 2)], summaries[int(length / 2):]
)
]
)
Expand Down Expand Up @@ -127,10 +193,9 @@ async def build_community(


async def build_communities(
driver: AsyncDriver, llm_client: LLMClient
driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
projection = await build_community_projection(driver)
community_clusters = await get_community_clusters(driver, projection)
community_clusters = await get_community_clusters(driver)

communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
await asyncio.gather(
Expand All @@ -144,7 +209,6 @@ async def build_communities(
community_nodes.append(community[0])
community_edges.extend(community[1])

await destroy_projection(driver, projection)
return community_nodes, community_edges


Expand All @@ -156,7 +220,7 @@ async def remove_communities(driver: AsyncDriver):


async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode
driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -217,7 +281,7 @@ async def determine_entity_community(


async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)

Expand Down
1 change: 1 addition & 0 deletions tests/test_graphiti_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()

edges = await graphiti.search('tania tetlow', group_ids=['1'])

Expand Down