Skip to content

Commit

Permalink
in memory graph detection
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Sep 21, 2024
1 parent 1aa7bf9 commit 2c524a6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 53 deletions.
68 changes: 34 additions & 34 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@

class Graphiti:
def __init__(
self,
uri: str,
user: str,
password: str,
llm_client: LLMClient | None = None,
store_raw_episode_content: bool = True,
self,
uri: str,
user: str,
password: str,
llm_client: LLMClient | None = None,
store_raw_episode_content: bool = True,
):
"""
Initialize a Graphiti instance.
Expand Down Expand Up @@ -194,10 +194,10 @@ async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
Expand Down Expand Up @@ -227,15 +227,15 @@ async def retrieve_episodes(
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
group_id: str | None = None,
uuid: str | None = None,
update_communities: bool = False,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
group_id: str | None = None,
uuid: str | None = None,
update_communities: bool = False,
):
"""
Process an episode and update the graph.
Expand Down Expand Up @@ -574,11 +574,11 @@ async def build_communities(self):
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])

async def search(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
Expand Down Expand Up @@ -630,22 +630,22 @@ async def search(
return edges

async def _search(
    self,
    query: str,
    config: SearchConfig,
    group_ids: list[str | None] | None = None,
    center_node_uuid: str | None = None,
) -> SearchResults:
    """
    Run a search with an explicit search configuration.

    Thin wrapper that forwards to the module-level ``search`` function,
    supplying this instance's driver and the embedder obtained from
    ``self.llm_client.get_embedder()``.

    Parameters
    ----------
    query : str
        The text query to search for.
    config : SearchConfig
        Search configuration passed through unchanged to ``search``.
    group_ids : list[str | None] | None, optional
        Group ids forwarded to ``search``; defaults to None.
    center_node_uuid : str | None, optional
        UUID of a center node forwarded to ``search``; defaults to None.

    Returns
    -------
    SearchResults
        Whatever ``search`` returns for the given arguments.
    """
    # Each parameter is declared exactly once: repeated parameter names in a
    # single signature are rejected by Python as a SyntaxError.
    return await search(
        self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
    )

async def get_nodes_by_query(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
limit: int = DEFAULT_SEARCH_LIMIT,
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
limit: int = DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
Expand Down
44 changes: 25 additions & 19 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ async def destroy_projection(driver: AsyncDriver, projection_name: str):
)


async def get_community_clusters(
driver: AsyncDriver
) -> list[list[EntityNode]]:
async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
community_clusters: list[list[EntityNode]] = []

group_id_values, _, _ = await driver.execute_query("""
Expand All @@ -60,16 +58,21 @@ async def get_community_clusters(
projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
for node in nodes:
records, _, _ = await driver.execute_query("""
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
WITH count(r) AS count, m.uuid AS uuid
RETURN
uuid,
count
""", uuid=node.uuid, group_id=group_id)
""",
uuid=node.uuid,
group_id=group_id,
)

projection[node.uuid] = [Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in
records]
projection[node.uuid] = [
Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
]

cluster_uuids = label_propagation(projection)

Expand All @@ -96,27 +99,30 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
for neighbor in neighbors:
community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count

community_lst = [(count, community) for community, count in community_candidates.items()]
community_lst = [
(count, community) for community, count in community_candidates.items()
]

community_lst.sort(reverse=True)

new_community = max(community_lst[0][1], curr_community)

new_community_map[uuid] = new_community

if new_community != curr_community:
no_change = False
new_community_map[uuid] = new_community

if no_change:
break

community_map = new_community_map

community_cluster_map = defaultdict(list)
for uuid, community in community_map.items():
community_cluster_map[community].append(uuid)
community_cluster_map = defaultdict(list)
for uuid, community in community_map.items():
community_cluster_map[community].append(uuid)

clusters = [cluster for cluster in community_cluster_map.values()]
return clusters
clusters = [cluster for cluster in community_cluster_map.values()]
return clusters


async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
Expand Down Expand Up @@ -145,7 +151,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s


async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode]
llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
length = len(summaries)
Expand All @@ -159,7 +165,7 @@ async def build_community(
*[
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip(
summaries[: int(length / 2)], summaries[int(length / 2):]
summaries[: int(length / 2)], summaries[int(length / 2) :]
)
]
)
Expand Down Expand Up @@ -187,7 +193,7 @@ async def build_community(


async def build_communities(
driver: AsyncDriver, llm_client: LLMClient
driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver)

Expand All @@ -214,7 +220,7 @@ async def remove_communities(driver: AsyncDriver):


async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode
driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -275,7 +281,7 @@ async def determine_entity_community(


async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)

Expand Down

0 comments on commit 2c524a6

Please sign in to comment.