diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index a75fd198..343f7d4a 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -66,6 +66,42 @@ class Graphiti: def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): + """ + Initialize a Graphiti instance. + + This constructor sets up a connection to the Neo4j database and initializes + the LLM client for natural language processing tasks. + + Parameters + ---------- + uri : str + The URI of the Neo4j database. + user : str + The username for authenticating with the Neo4j database. + password : str + The password for authenticating with the Neo4j database. + llm_client : LLMClient | None, optional + An instance of LLMClient for natural language processing tasks. + If not provided, a default OpenAIClient will be initialized. + + Returns + ------- + None + + Notes + ----- + This method establishes a connection to the Neo4j database using the provided + credentials. It also sets up the LLM client, either using the provided client + or by creating a default OpenAIClient. + + The default database name is set to 'neo4j'. If a different database name + is required, it should be specified in the URI or set separately after + initialization. + + The OpenAI API key is expected to be set in the environment variables. + Make sure to set the OPENAI_API_KEY environment variable before initializing + Graphiti if you're using the default OpenAIClient. + """ self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) self.database = 'neo4j' if llm_client: @@ -79,9 +115,67 @@ def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | N ) def close(self): + """ + Close the connection to the Neo4j database. + + This method safely closes the driver connection to the Neo4j database. + It should be called when the Graphiti instance is no longer needed or + when the application is shutting down. 
+ + Parameters + ---------- + None + + Returns + ------- + None + + Notes + ----- + It's important to close the driver connection to release system resources + and ensure that all pending transactions are completed or rolled back. + This method should be called as part of a cleanup process, potentially + in a context manager or a shutdown hook. + + Example: + graphiti = Graphiti(uri, user, password) + try: + # Use graphiti... + finally: + graphiti.close() + """ self.driver.close() async def build_indices_and_constraints(self): + """ + Build indices and constraints in the Neo4j database. + + This method sets up the necessary indices and constraints in the Neo4j database + to optimize query performance and ensure data integrity for the knowledge graph. + + Parameters + ---------- + None + + Returns + ------- + None + + Notes + ----- + This method should typically be called once during the initial setup of the + knowledge graph or when updating the database schema. It uses the + `build_indices_and_constraints` function from the + `graphiti_core.utils.maintenance.graph_data_operations` module to perform + the actual database operations. + + The specific indices and constraints created depend on the implementation + of the `build_indices_and_constraints` function. Refer to that function's + documentation for details on the exact database schema modifications. + + Caution: Running this method on a large existing database may take some time + and could impact database performance during execution. + """ await build_indices_and_constraints(self.driver) async def retrieve_episodes( @@ -89,7 +183,29 @@ async def retrieve_episodes( reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: - """Retrieve the last n episodic nodes from the graph""" + """ + Retrieve the last n episodic nodes from the graph. + + This method fetches a specified number of the most recent episodic nodes + from the graph, relative to the given reference time. 
+ + Parameters + ---------- + reference_time : datetime + The reference time to retrieve episodes before. + last_n : int, optional + The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN. + + Returns + ------- + list[EpisodicNode] + A list of the most recent EpisodicNode objects. + + Notes + ----- + The actual retrieval is performed by the `retrieve_episodes` function + from the `graphiti_core.utils` module. + """ return await retrieve_episodes(self.driver, reference_time, last_n) async def add_episode( @@ -102,7 +218,50 @@ async def add_episode( success_callback: Callable | None = None, error_callback: Callable | None = None, ): - """Process an episode and update the graph""" + """ + Process an episode and update the graph. + + This method extracts information from the episode, creates nodes and edges, + and updates the graph database accordingly. + + Parameters + ---------- + name : str + The name of the episode. + episode_body : str + The content of the episode. + source_description : str + A description of the episode's source. + reference_time : datetime + The reference time for the episode. + source : EpisodeType, optional + The type of the episode. Defaults to EpisodeType.message. + success_callback : Callable | None, optional + A callback function to be called upon successful processing. + error_callback : Callable | None, optional + A callback function to be called if an error occurs during processing. + + Returns + ------- + None + + Notes + ----- + This method performs several steps including node extraction, edge extraction, + deduplication, and database updates. It also handles embedding generation + and edge invalidation. + + It is recommended to run this method as a background process, such as in a queue. + It's important that each episode is added sequentially and awaited before adding + the next one. For web applications, consider using FastAPI's background tasks + or a dedicated task queue like Celery for this purpose. 
+ + Example using FastAPI background tasks: + @app.post("/add_episode") + async def add_episode_endpoint(episode_data: EpisodeData): + background_tasks.add_task(graphiti.add_episode, **episode_data.dict()) + return {"message": "Episode processing started"} + """ try: start = time() @@ -255,6 +414,40 @@ async def add_episode_bulk( self, bulk_episodes: list[RawEpisode], ): + """ + Process multiple episodes in bulk and update the graph. + + This method extracts information from multiple episodes, creates nodes and edges, + and updates the graph database accordingly, all in a single batch operation. + + Parameters + ---------- + bulk_episodes : list[RawEpisode] + A list of RawEpisode objects to be processed and added to the graph. + + Returns + ------- + None + + Notes + ----- + This method performs several steps including: + - Saving all episodes to the database + - Retrieving previous episode context for each new episode + - Extracting nodes and edges from all episodes + - Generating embeddings for nodes and edges + - Deduplicating nodes and edges + - Saving nodes, episodic edges, and entity edges to the knowledge graph + + This bulk operation is designed for efficiency when processing multiple episodes + at once. However, it's important to ensure that the bulk operation doesn't + overwhelm system resources. Consider implementing rate limiting or chunking for + very large batches of episodes. + + Important: This method does not perform edge invalidation or date extraction steps. + If these operations are required, use the `add_episode` method instead for each + individual episode. + """ try: start = time() embedder = self.llm_client.get_embedder() @@ -329,6 +522,33 @@ async def add_episode_bulk( raise e async def search(self, query: str, num_results=10): + """ + Perform a hybrid search on the knowledge graph. + + This method executes a search query on the graph, combining vector and + text-based search techniques to retrieve relevant facts. 
+ + Parameters + ---------- + query : str + The search query string. + num_results : int, optional + The maximum number of results to return. Defaults to 10. + + Returns + ------- + list + A list of facts (strings) that are relevant to the search query. + + Notes + ----- + This method uses a SearchConfig with num_episodes set to 0 and + num_results set to the provided num_results parameter. It then calls + the hybrid_search function to perform the actual search operation. + + The search is performed using the current date and time as the reference + point for temporal relevance. + """ search_config = SearchConfig(num_episodes=0, num_results=num_results) edges = ( await hybrid_search( diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 72ba0249..6a4df2bb 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -31,6 +31,23 @@ class EpisodeType(Enum): + """ + Enumeration of different types of episodes that can be processed. + + This enum defines the various sources or formats of episodes that the system + can handle. It's used to categorize and potentially handle different types + of input data differently. + + Attributes: + ----------- + message : str + Represents a standard message-type episode. The content for this type + should be formatted as "actor: content". For example, "user: Hello, how are you?" + or "assistant: I'm doing well, thank you for asking." + json : str + Represents an episode containing a JSON string object with structured data. + """ + message = 'message' json = 'json' diff --git a/tests/tests_int_graphiti.py b/tests/test_graphiti_int.py similarity index 85% rename from tests/tests_int_graphiti.py rename to tests/test_graphiti_int.py index 410f72d8..2ab6c8f0 100644 --- a/tests/tests_int_graphiti.py +++ b/tests/test_graphiti_int.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + import asyncio import logging import os diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py index dcee3573..76224bc5 100644 --- a/tests/utils/maintenance/test_temporal_operations.py +++ b/tests/utils/maintenance/test_temporal_operations.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + import unittest from datetime import datetime, timedelta diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 16591ef1..9e6b2953 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + import os from datetime import datetime, timedelta