Skip to content

Commit

Permalink
fixed an issue where solo nodes would throw an error when building co…
Browse files Browse the repository at this point in the history
…mmunities
  • Loading branch information
prasmussen15 committed Sep 23, 2024
1 parent 7a43c41 commit 3d60056
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
20 changes: 10 additions & 10 deletions examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()

if not use_bulk:
for i, message in enumerate(messages[3:14]):
for i, message in enumerate(messages[3:4]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
Expand All @@ -76,15 +76,15 @@ async def main(use_bulk: bool = True):
await client.build_communities()

# add additional messages to update communities
for i, message in enumerate(messages[14:20]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
update_communities=True,
)
# for i, message in enumerate(messages[14:20]):
# await client.add_episode(
# name=f'Message {i}',
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
# reference_time=message.actual_timestamp,
# source_description='Podcast Transcript',
# group_id='1',
# update_communities=True,
# )

return

Expand Down
13 changes: 7 additions & 6 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
]

community_lst.sort(reverse=True)
community_candidate = community_lst[0][1] if len(community_lst) > 0 else -1

new_community = max(community_lst[0][1], curr_community)
new_community = max(community_candidate, curr_community)

new_community_map[uuid] = new_community

Expand Down Expand Up @@ -150,7 +151,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s


async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode]
llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
length = len(summaries)
Expand All @@ -164,7 +165,7 @@ async def build_community(
*[
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip(
summaries[: int(length / 2)], summaries[int(length / 2) :]
summaries[: int(length / 2)], summaries[int(length / 2):]
)
]
)
Expand Down Expand Up @@ -192,7 +193,7 @@ async def build_community(


async def build_communities(
driver: AsyncDriver, llm_client: LLMClient
driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver)

Expand All @@ -219,7 +220,7 @@ async def remove_communities(driver: AsyncDriver):


async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode
driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -280,7 +281,7 @@ async def determine_entity_community(


async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)

Expand Down

0 comments on commit 3d60056

Please sign in to comment.