Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mentions reranker #124

Merged
merged 9 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
update_communities=True,
)

return
Expand Down
79 changes: 77 additions & 2 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
raise EdgeNotFoundError(uuid)
return edges[0]

@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch the MENTIONS edges (Episodic -> Entity) whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Edge uuids to look up.

    Returns
    -------
    list
        One edge per matching record, built with
        ``get_episodic_edge_from_record``. An empty ``uuids`` list yields an
        empty result without querying the database.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty and no edge matches any of them.
    """
    # Nothing to look up; also keeps the uuids[0] access below safe.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        uuids=uuids,
    )

    edges = [get_episodic_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        # The error carries a single uuid, so report the first one requested.
        raise EdgeNotFoundError(uuids[0])
    return edges


class EntityEdge(Edge):
name: str = Field(description='name of the edge, relation name')
fact: str = Field(description='fact representing the edge and nodes that it connects')
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
episodes: list[str] | None = Field(
default=None,
episodes: list[str] = Field(
default=[],
description='list of episode ids that reference these entity edges',
)
expired_at: datetime | None = Field(
Expand Down Expand Up @@ -197,6 +220,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
raise EdgeNotFoundError(uuid)
return edges[0]

@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch the RELATES_TO edges (Entity -> Entity) whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Edge uuids to look up.

    Returns
    -------
    list
        One edge per matching record, built with
        ``get_entity_edge_from_record``. An empty ``uuids`` list yields an
        empty result without querying the database.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty and no edge matches any of them.
    """
    # Nothing to look up; also keeps the uuids[0] access below safe.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
            e.expired_at AS expired_at,
            e.valid_at AS valid_at,
            e.invalid_at AS invalid_at
        """,
        uuids=uuids,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        # The error carries a single uuid, so report the first one requested.
        raise EdgeNotFoundError(uuids[0])
    return edges


class CommunityEdge(Edge):
async def save(self, driver: AsyncDriver):
Expand Down Expand Up @@ -239,6 +292,28 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):

return edges[0]

@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch the HAS_MEMBER edges whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Edge uuids to look up.

    Returns
    -------
    list
        One edge per matching record, built with
        ``get_community_edge_from_record``; empty when nothing matches
        (this lookup does not raise on a miss).
    """
    query_result = await driver.execute_query(
        """
        MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        uuids=uuids,
    )
    # execute_query yields a 3-tuple; only the records are needed here.
    rows, _, _ = query_result

    community_edges = []
    for row in rows:
        community_edges.append(get_community_edge_from_record(row))

    logger.info(f'Found Edges: {uuids}')

    return community_edges


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
Expand Down
22 changes: 20 additions & 2 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
)
Expand Down Expand Up @@ -249,8 +251,6 @@ async def add_episode(
An id for the graph partition the episode is a part of.
uuid : str | None
Optional uuid of the episode.
update_communities: bool
Optional. Determines if we should update communities

Returns
-------
Expand Down Expand Up @@ -413,6 +413,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):

logger.info(f'Built episodic edges: {episodic_edges}')

episode.entity_edges = [edge.uuid for edge in entity_edges]

# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
Expand Down Expand Up @@ -680,3 +682,19 @@ async def get_nodes_by_query(
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
).nodes
return nodes


async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
    """Collect the edges, nodes and communities tied to the given episodes.

    Parameters
    ----------
    episode_uuids : list[str]
        Uuids of the episodes to inspect.

    Returns
    -------
    SearchResults
        ``edges`` are the entity edges listed in each episode's
        ``entity_edges``, ``nodes`` come from ``get_mentioned_nodes`` and
        ``communities`` from ``get_communities_by_nodes`` for those nodes.
    """
    episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # Only query for episodes that actually reference entity edges; an empty
    # uuid list has nothing to fetch and must not abort the whole gather.
    edges_list = await asyncio.gather(
        *[
            EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
            for episode in episodes
            if episode.entity_edges
        ]
    )

    # Flatten the per-episode results into a single list.
    edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

    nodes = await get_mentioned_nodes(self.driver, episodes)

    communities = await get_communities_by_nodes(self.driver, nodes)

    return SearchResults(edges=edges, nodes=nodes, communities=communities)
3 changes: 2 additions & 1 deletion graphiti_core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.uuid IN $uuids
RETURN e.content AS content,
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
Expand Down
8 changes: 7 additions & 1 deletion graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
community_similarity_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
Expand Down Expand Up @@ -131,7 +132,7 @@ async def edge_search(
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

reranked_uuids: list[str] = []
if config.reranker == EdgeReranker.rrf:
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

reranked_uuids = rrf(search_result_uuids)
Expand All @@ -150,6 +151,9 @@ async def edge_search(

reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))

return reranked_edges


Expand Down Expand Up @@ -189,6 +193,8 @@ async def node_search(
reranked_uuids: list[str] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
Expand Down
2 changes: 2 additions & 0 deletions graphiti_core/search/search_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum):
class EdgeReranker(Enum):
    """Reranking strategies selectable for edge search results."""

    # Reciprocal rank fusion
    rrf = 'reciprocal_rank_fusion'
    # Node-distance based ordering
    node_distance = 'node_distance'
    # Episode-mentions based ordering
    episode_mentions = 'episode_mentions'


class NodeReranker(Enum):
    """Reranking strategies selectable for node search results."""

    # Reciprocal rank fusion
    rrf = 'reciprocal_rank_fusion'
    # Node-distance based ordering
    node_distance = 'node_distance'
    # Episode-mentions based ordering
    episode_mentions = 'episode_mentions'


class CommunityReranker(Enum):
Expand Down
16 changes: 16 additions & 0 deletions graphiti_core/search/search_config_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@
)
)

# Hybrid edge search (BM25 + cosine similarity) reranked by episode mentions
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    edge_config=EdgeSearchConfig(
        reranker=EdgeReranker.episode_mentions,
        search_methods=[
            EdgeSearchMethod.bm25,
            EdgeSearchMethod.cosine_similarity,
        ],
    )
)

# performs a hybrid search over nodes with rrf reranking
NODE_HYBRID_SEARCH_RRF = SearchConfig(
node_config=NodeSearchConfig(
Expand All @@ -75,6 +83,14 @@
)
)

# Hybrid node search (BM25 + cosine similarity) reranked by episode mentions
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    node_config=NodeSearchConfig(
        reranker=NodeReranker.episode_mentions,
        search_methods=[
            NodeSearchMethod.bm25,
            NodeSearchMethod.cosine_similarity,
        ],
    )
)

# performs a hybrid search over communities with rrf reranking
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
community_config=CommunitySearchConfig(
Expand Down
58 changes: 57 additions & 1 deletion graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
RELEVANT_SCHEMA_LIMIT = 3


async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
Expand All @@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
return nodes


async def get_communities_by_nodes(
    driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """Return the distinct communities that have any of ``nodes`` as a member.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    nodes : list[EntityNode]
        Entity nodes whose communities are looked up (matched by uuid).

    Returns
    -------
    list[CommunityNode]
        One community per distinct record, built with
        ``get_community_node_from_record``.
    """
    node_uuids = [node.uuid for node in nodes]
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
        RETURN DISTINCT
            c.uuid As uuid,
            c.group_id AS group_id,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        uuids=node_uuids,
    )

    communities = [get_community_node_from_record(record) for record in records]

    return communities


async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
Expand Down Expand Up @@ -634,3 +659,34 @@ async def node_distance_reranker(
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

return sorted_uuids


async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    """Rerank node uuids by how many episodes mention each node.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the per-node count queries.
    node_uuids : list[list[str]]
        One ranked uuid list per search method; fused with ``rrf`` first.

    Returns
    -------
    list[str]
        The fused uuids ordered from most-mentioned to least-mentioned.
        Ties keep their RRF order because the sort is stable.
    """
    # use rrf as a preliminary ranker
    sorted_uuids = rrf(node_uuids)

    # Count the Episodic nodes that have a MENTIONS edge to the candidate node
    query = Query("""
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
        RETURN count(*) AS score
    """)

    result_scores = await asyncio.gather(
        *[
            driver.execute_query(
                query,
                node_uuid=uuid,
            )
            for uuid in sorted_uuids
        ]
    )

    # Each result is (records, summary, keys); the aggregate is read from the first record.
    scores: dict[str, float] = {
        uuid: result[0][0]['score'] for uuid, result in zip(sorted_uuids, result_scores)
    }

    # More mentions rank first, matching the edge reranker which sorts by
    # len(edge.episodes) in descending order.
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return sorted_uuids
2 changes: 2 additions & 0 deletions graphiti_core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ async def dedupe_extracted_edges(
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
# Add current episode to the episodes list
existing_edge.episodes += edge.episodes
edges.append(existing_edge)
else:
edges.append(edge)
Expand Down