diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 22e1d90b..90a4a205 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True): messages = parse_podcast_messages() if not use_bulk: - for i, message in enumerate(messages[3:14]): + for i, message in enumerate(messages[3:4]): await client.add_episode( name=f'Message {i}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}', @@ -76,15 +76,15 @@ async def main(use_bulk: bool = True): await client.build_communities() # add additional messages to update communities - for i, message in enumerate(messages[14:20]): - await client.add_episode( - name=f'Message {i}', - episode_body=f'{message.speaker_name} ({message.role}): {message.content}', - reference_time=message.actual_timestamp, - source_description='Podcast Transcript', - group_id='1', - update_communities=True, - ) + # for i, message in enumerate(messages[14:20]): + # await client.add_episode( + # name=f'Message {i}', + # episode_body=f'{message.speaker_name} ({message.role}): {message.content}', + # reference_time=message.actual_timestamp, + # source_description='Podcast Transcript', + # group_id='1', + # update_communities=True, + # ) return diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index b72c995f..a239b691 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -103,8 +103,9 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]: ] community_lst.sort(reverse=True) + community_candidate = community_lst[0][1] if len(community_lst) > 0 else -1 - new_community = max(community_lst[0][1], curr_community) + new_community = max(community_candidate, curr_community) new_community_map[uuid] = new_community @@ -150,7 +151,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s async def build_community( - llm_client: LLMClient, community_cluster: list[EntityNode] + llm_client: LLMClient, community_cluster: list[EntityNode] ) -> tuple[CommunityNode, list[CommunityEdge]]: summaries = [entity.summary for entity in community_cluster] length = len(summaries) @@ -164,7 +165,7 @@ async def build_community( *[ summarize_pair(llm_client, (str(left_summary), str(right_summary))) for left_summary, right_summary in zip( - summaries[: int(length / 2)], summaries[int(length / 2) :] + summaries[: int(length / 2)], summaries[int(length / 2):] ) ] ) @@ -192,7 +193,7 @@ async def build_community( async def build_communities( - driver: AsyncDriver, llm_client: LLMClient + driver: AsyncDriver, llm_client: LLMClient ) -> tuple[list[CommunityNode], list[CommunityEdge]]: community_clusters = await get_community_clusters(driver) @@ -219,7 +220,7 @@ async def remove_communities(driver: AsyncDriver): async def determine_entity_community( - driver: AsyncDriver, entity: EntityNode + driver: AsyncDriver, entity: EntityNode ) -> tuple[CommunityNode | None, bool]: # Check if the node is already part of a community records, _, _ = await driver.execute_query( @@ -280,7 +281,7 @@ async def determine_entity_community( async def update_community( - driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode + driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode ): community, is_new = await determine_entity_community(driver, entity)