Skip to content

Commit

Permalink
Mentions reranker (#124)
Browse files Browse the repository at this point in the history
* documentation update

* update communities

* mentions reranker

* fix episode edge mentions

* get episode mentions

* add communities to mentions endpoint

* rebase

* defaults episodes to empty list

* update
  • Loading branch information
prasmussen15 committed Sep 18, 2024
1 parent d133c39 commit e398f95
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 7 deletions.
1 change: 1 addition & 0 deletions examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
update_communities=True,
)

return
Expand Down
79 changes: 77 additions & 2 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
raise EdgeNotFoundError(uuid)
return edges[0]

@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch the episodic (``MENTIONS``) edges whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Edge uuids to look up.

    Returns
    -------
    list
        One edge object per matching record, built with
        ``get_episodic_edge_from_record``. An empty ``uuids`` list yields an
        empty result without querying the database.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty but no edge matches; the first requested
        uuid is reported.
    """
    # Nothing to look up. Returning early also keeps the not-found branch
    # below from indexing into an empty list (uuids[0] -> IndexError).
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        uuids=uuids,
    )

    edges = [get_episodic_edge_from_record(record) for record in records]

    # NOTE: this logs the *requested* uuids, not the uuids actually found.
    logger.info(f'Found Edges: {uuids}')
    if not edges:
        raise EdgeNotFoundError(uuids[0])
    return edges


class EntityEdge(Edge):
name: str = Field(description='name of the edge, relation name')
fact: str = Field(description='fact representing the edge and nodes that it connects')
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
episodes: list[str] | None = Field(
default=None,
episodes: list[str] = Field(
default=[],
description='list of episode ids that reference these entity edges',
)
expired_at: datetime | None = Field(
Expand Down Expand Up @@ -197,6 +220,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
raise EdgeNotFoundError(uuid)
return edges[0]

@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch the entity (``RELATES_TO``) edges whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Edge uuids to look up.

    Returns
    -------
    list
        One edge object per matching record, built with
        ``get_entity_edge_from_record``. An empty ``uuids`` list yields an
        empty result without querying the database.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty but no edge matches; the first requested
        uuid is reported.
    """
    # Callers may pass an episode's (possibly empty) entity_edges list.
    # Returning early avoids a pointless query and keeps the not-found
    # branch below from raising IndexError on uuids[0].
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
            e.expired_at AS expired_at,
            e.valid_at AS valid_at,
            e.invalid_at AS invalid_at
        """,
        uuids=uuids,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    # NOTE: this logs the *requested* uuids, not the uuids actually found.
    logger.info(f'Found Edges: {uuids}')
    if not edges:
        raise EdgeNotFoundError(uuids[0])
    return edges


class CommunityEdge(Edge):
async def save(self, driver: AsyncDriver):
Expand Down Expand Up @@ -239,6 +292,28 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):

return edges[0]

@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch the community (``HAS_MEMBER``) edges whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    uuids : list[str]
        Edge uuids to look up.

    Returns
    -------
    list
        One edge object per matching record, built with
        ``get_community_edge_from_record``. The list is empty when nothing
        matches; no not-found error is raised here.
    """
    cypher = """
        MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """
    query_result = await driver.execute_query(cypher, uuids=uuids)
    # execute_query returns a 3-tuple; only the records (first item) are used
    matched_records = query_result[0]

    community_edges = list(map(get_community_edge_from_record, matched_records))

    logger.info(f'Found Edges: {uuids}')

    return community_edges


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
Expand Down
22 changes: 20 additions & 2 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
)
Expand Down Expand Up @@ -249,8 +251,6 @@ async def add_episode(
An id for the graph partition the episode is a part of.
uuid : str | None
Optional uuid of the episode.
update_communities: bool
Optional. Determines if we should update communities
Returns
-------
Expand Down Expand Up @@ -413,6 +413,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):

logger.info(f'Built episodic edges: {episodic_edges}')

episode.entity_edges = [edge.uuid for edge in entity_edges]

# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
Expand Down Expand Up @@ -680,3 +682,19 @@ async def get_nodes_by_query(
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
).nodes
return nodes


async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
    """Collect the edges, nodes and communities tied to the given episodes.

    Parameters
    ----------
    episode_uuids : list[str]
        Uuids of the episodic nodes to inspect.

    Returns
    -------
    SearchResults
        ``edges`` holds the entity edges listed in each episode's
        ``entity_edges`` (flattened in episode order, without
        de-duplication), ``nodes`` the nodes returned by
        ``get_mentioned_nodes`` for those episodes, and ``communities`` the
        communities returned by ``get_communities_by_nodes`` for those nodes.
    """
    found_episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # One edge lookup per episode, all awaited concurrently
    edge_lookups = [
        EntityEdge.get_by_uuids(self.driver, found.entity_edges) for found in found_episodes
    ]
    per_episode_edges = await asyncio.gather(*edge_lookups)

    edges: list[EntityEdge] = []
    for batch in per_episode_edges:
        edges.extend(batch)

    mentioned_nodes = await get_mentioned_nodes(self.driver, found_episodes)
    node_communities = await get_communities_by_nodes(self.driver, mentioned_nodes)

    return SearchResults(edges=edges, nodes=mentioned_nodes, communities=node_communities)
3 changes: 2 additions & 1 deletion graphiti_core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.uuid IN $uuids
RETURN e.content AS content,
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
Expand Down
8 changes: 7 additions & 1 deletion graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
community_similarity_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
Expand Down Expand Up @@ -131,7 +132,7 @@ async def edge_search(
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

reranked_uuids: list[str] = []
if config.reranker == EdgeReranker.rrf:
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

reranked_uuids = rrf(search_result_uuids)
Expand All @@ -150,6 +151,9 @@ async def edge_search(

reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))

return reranked_edges


Expand Down Expand Up @@ -189,6 +193,8 @@ async def node_search(
reranked_uuids: list[str] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
Expand Down
2 changes: 2 additions & 0 deletions graphiti_core/search/search_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum):
class EdgeReranker(Enum):
    """Reranking strategies that can be selected for edge search results."""

    # Reciprocal rank fusion of the per-search-method result lists
    rrf = 'reciprocal_rank_fusion'
    # NOTE(review): presumably ranks by graph distance to a center node - confirm in edge_search
    node_distance = 'node_distance'
    # RRF first, then edges referenced by more episodes are sorted to the front
    episode_mentions = 'episode_mentions'


class NodeReranker(Enum):
    """Reranking strategies that can be selected for node search results."""

    # Reciprocal rank fusion of the per-search-method result lists
    rrf = 'reciprocal_rank_fusion'
    # Requires a center node uuid; node_search raises SearchRerankerError without one
    node_distance = 'node_distance'
    # Delegates to episode_mentions_reranker, which scores nodes by their MENTIONS count
    episode_mentions = 'episode_mentions'


class CommunityReranker(Enum):
Expand Down
16 changes: 16 additions & 0 deletions graphiti_core/search/search_config_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@
)
)

# Hybrid search over edges: combines BM25 and cosine-similarity search methods and
# reranks the merged results with the episode-mentions reranker (edges referenced
# by more episodes rank higher).
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.episode_mentions,
    )
)

# performs a hybrid search over nodes with rrf reranking
NODE_HYBRID_SEARCH_RRF = SearchConfig(
node_config=NodeSearchConfig(
Expand All @@ -75,6 +83,14 @@
)
)

# Hybrid search over nodes: combines BM25 and cosine-similarity search methods and
# reranks the merged results with the episode-mentions reranker, which scores each
# node by the number of episodes that mention it.
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.episode_mentions,
    )
)

# performs a hybrid search over communities with rrf reranking
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
community_config=CommunitySearchConfig(
Expand Down
58 changes: 57 additions & 1 deletion graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
RELEVANT_SCHEMA_LIMIT = 3


async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
Expand All @@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
return nodes


async def get_communities_by_nodes(
    driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """Return the distinct communities that have any of ``nodes`` as a member.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the Cypher query.
    nodes : list[EntityNode]
        Entity nodes whose communities should be looked up; only their uuids
        are sent to the database.

    Returns
    -------
    list[CommunityNode]
        One community per distinct record, built with
        ``get_community_node_from_record``. Empty when no community matches.
    """
    node_uuids = [node.uuid for node in nodes]
    # The RETURN items must be comma-separated: the original query had no comma
    # after the name_embedding projection, which made it invalid Cypher.
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
        RETURN DISTINCT
            c.uuid As uuid,
            c.group_id AS group_id,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        uuids=node_uuids,
    )

    communities = [get_community_node_from_record(record) for record in records]

    return communities


async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
Expand Down Expand Up @@ -634,3 +659,34 @@ async def node_distance_reranker(
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

return sorted_uuids


async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    """Rerank node uuids so that the most frequently mentioned nodes come first.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the per-node count queries.
    node_uuids : list[list[str]]
        One ranked list of node uuids per search method.

    Returns
    -------
    list[str]
        The uuids produced by ``rrf(node_uuids)``, reordered by descending
        number of ``MENTIONS`` relationships pointing at each node. Nodes with
        equal counts keep their relative RRF order because the sort is stable.
    """
    # use rrf as a preliminary ranker
    sorted_uuids = rrf(node_uuids)
    scores: dict[str, float] = {}

    # Count the episodes that mention a given entity node
    query = Query("""
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
        RETURN count(*) AS score
        """)

    # One count query per candidate node, all awaited concurrently
    result_scores = await asyncio.gather(
        *[
            driver.execute_query(
                query,
                node_uuid=uuid,
            )
            for uuid in sorted_uuids
        ]
    )

    for uuid, result in zip(sorted_uuids, result_scores):
        # result[0] is the record list; the count(*) aggregate is expected to
        # yield exactly one row, even when the node has no mentions (score 0)
        record = result[0][0]
        scores[uuid] = record['score']

    # rerank on mention count, highest first (matches the edge reranker, which
    # sorts by number of episodes in descending order)
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return sorted_uuids
2 changes: 2 additions & 0 deletions graphiti_core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ async def dedupe_extracted_edges(
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
# Add current episode to the episodes list
existing_edge.episodes += edge.episodes
edges.append(existing_edge)
else:
edges.append(edge)
Expand Down

0 comments on commit e398f95

Please sign in to comment.