Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search refactor + Community search #111

Merged
merged 12 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()

if not use_bulk:
for i, message in enumerate(messages[3:130]):
for i, message in enumerate(messages[3:20]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
Expand Down
25 changes: 25 additions & 0 deletions graphiti_core/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


class GraphitiError(Exception):
"""Base exception class for Graphiti Core."""

Expand All @@ -16,3 +33,11 @@ class NodeNotFoundError(GraphitiError):
def __init__(self, uuid: str):
self.message = f'node {uuid} not found'
super().__init__(self.message)


class SearchRerankerError(GraphitiError):
    """Raised for errors related to search reranking.

    Parameters
    ----------
    text : str
        Human-readable description of the failure. It is stored on
        ``self.message`` and passed to the base exception, so it is also
        what ``str(exc)`` returns.
    """

    def __init__(self, text: str):
        # Keep the message on the instance, matching the other Graphiti
        # exceptions that expose a ``message`` attribute.
        self.message = text
        super().__init__(self.message)
58 changes: 33 additions & 25 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.llm_client.utils import generate_embedding
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search import SearchConfig, search
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
from graphiti_core.search.search_config_recipes import (
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
EDGE_HYBRID_SEARCH_RRF,
NODE_HYBRID_SEARCH_NODE_DISTANCE,
NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_relevant_edges,
get_relevant_nodes,
hybrid_node_search,
)
from graphiti_core.utils import (
build_episodic_edges,
Expand Down Expand Up @@ -548,7 +553,7 @@ async def search(
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=10,
num_results=DEFAULT_SEARCH_LIMIT,
):
"""
Perform a hybrid search on the knowledge graph.
Expand All @@ -564,7 +569,7 @@ async def search(
Facts will be reranked based on proximity to this node
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
num_results : int, optional
num_results : int, optional
The maximum number of results to return. Defaults to DEFAULT_SEARCH_LIMIT.

Returns
Expand All @@ -581,21 +586,17 @@ async def search(
The search is performed using the current date and time as the reference
point for temporal relevance.
"""
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
search_config = SearchConfig(
num_episodes=0,
num_edges=num_results,
num_nodes=0,
group_ids=group_ids,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
search_config = (
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
)
search_config.limit = num_results

edges = (
await hybrid_search(
await search(
self.driver,
self.llm_client.get_embedder(),
query,
datetime.now(),
group_ids,
search_config,
center_node_uuid,
)
Expand All @@ -606,19 +607,20 @@ async def search(
async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
group_ids: list[str | None] | None = None,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
) -> SearchResults:
return await search(
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
)

async def get_nodes_by_query(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
limit: int = DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
Expand All @@ -629,7 +631,9 @@ async def get_nodes_by_query(
Parameters
----------
query : str
The text query to search for in the graph.
The text query to search for in the graph
center_node_uuid : str | None, optional
Nodes will be reranked based on proximity to this node.
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int | None, optional
Expand All @@ -655,8 +659,12 @@ async def get_nodes_by_query(
If not specified, a default limit (defined in the search functions) will be used.
"""
embedder = self.llm_client.get_embedder()
query_embedding = await generate_embedding(embedder, query)
relevant_nodes = await hybrid_node_search(
[query], [query_embedding], self.driver, group_ids, limit
search_config = (
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
)
return relevant_nodes
search_config.limit = limit

nodes = (
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
).nodes
return nodes
16 changes: 16 additions & 0 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from datetime import datetime

from neo4j import time as neo4j_time
Expand Down
17 changes: 17 additions & 0 deletions graphiti_core/llm_client/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


class RateLimitError(Exception):
"""Exception raised when the rate limit is exceeded."""

Expand Down
18 changes: 17 additions & 1 deletion graphiti_core/llm_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import typing
from time import time
Expand All @@ -17,6 +33,6 @@ async def generate_embedding(
embedding = embedding[:EMBEDDING_DIM]

end = time()
logger.debug(f'embedded text of length {len(text)} in {end-start} ms')
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')

return embedding
16 changes: 16 additions & 0 deletions graphiti_core/prompts/extract_edge_dates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion
Expand Down
Loading