Skip to content

Commit

Permalink
Cleanup maintenance utilities + add podcast runner (#5)
Browse files Browse the repository at this point in the history
* chore: Fix minor issues with episodic edge building + cleanup

* feat: Port podcast runner

* feat: Port podcast runner
  • Loading branch information
paul-paliychuk authored Aug 16, 2024
1 parent f1c2224 commit ad552b5
Show file tree
Hide file tree
Showing 10 changed files with 679 additions and 36 deletions.
12 changes: 10 additions & 2 deletions core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ class Edge(BaseModel, ABC):
@abstractmethod
async def save(self, driver: AsyncDriver): ...

def __hash__(self):
return hash(self.uuid)

def __eq__(self, other):
if isinstance(other, Node):
return self.uuid == other.uuid
return False


class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
Expand Down Expand Up @@ -58,7 +66,8 @@ class EntityEdge(Edge):
default=None, description="embedding of the fact"
)
episodes: list[str] | None = Field(
default=None, description="list of episode ids that reference these entity edges"
default=None,
description="list of episode ids that reference these entity edges",
)
expired_at: datetime | None = Field(
default=None, description="datetime of when the node was invalidated"
Expand All @@ -79,7 +88,6 @@ def generate_embedding(self, embedder, model="text-embedding-3-large"):

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(

"""
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
Expand Down
14 changes: 11 additions & 3 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,25 @@ async def add_episode(
created_at=datetime.now(),
valid_at=reference_time,
)
await episode.save(self.driver)
# await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await extract_new_nodes(
self.llm_client, episode, relevant_schema, previous_episodes
)
nodes.extend(new_nodes)
new_edges = await extract_new_edges(
new_edges, affected_nodes = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
episodic_edges = build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
list(set(nodes + affected_nodes)),
episode,
datetime.now(),
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
nodes.append(episode)
logger.info(f"Built episodic edges: {episodic_edges}")
edges.extend(episodic_edges)

# invalidated_edges = await self.invalidate_edges(
Expand Down
6 changes: 5 additions & 1 deletion core/llm_client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from .client import LLMClient
from .config import LLMConfig

import logging

logger = logging.getLogger(__name__)


class OpenAIClient(LLMClient):
def __init__(self, config: LLMConfig):
Expand All @@ -20,5 +24,5 @@ async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, a
)
return json.loads(response.choices[0].message.content)
except Exception as e:
print(f"Error in generating LLM response: {e}")
logger.error(f"Error in generating LLM response: {e}")
raise
12 changes: 10 additions & 2 deletions core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ class Node(BaseModel, ABC):
@abstractmethod
async def save(self, driver: AsyncDriver): ...

def __hash__(self):
return hash(self.uuid)

def __eq__(self, other):
if isinstance(other, Node):
return self.uuid == other.uuid
return False


class EpisodicNode(Node):
source: str = Field(description="source type")
Expand All @@ -38,7 +46,7 @@ async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
Expand All @@ -48,11 +56,11 @@ async def save(self, driver: AsyncDriver):
entity_edges=self.entity_edges,
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source,
_database="neo4j",
)

logger.info(f"Saved Node to neo4j: {self.uuid}")
print(self.uuid)

return result

Expand Down
10 changes: 8 additions & 2 deletions core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def extract_new_edges(
new_nodes: list[EntityNode],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
) -> tuple[list[EntityEdge], list[EntityNode]]:
# Prepare context for LLM
context = {
"episode_content": episode.content,
Expand All @@ -58,6 +58,7 @@ async def extract_new_edges(
prompt_library.extract_edges.v1(context)
)
new_edges_data = llm_response.get("new_edges", [])
logger.info(f"Extracted new edges: {new_edges_data}")

# Convert the extracted data into EntityEdge objects
new_edges = []
Expand Down Expand Up @@ -125,4 +126,9 @@ async def extract_new_edges(
f"Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})"
)

return new_edges
affected_nodes = set()

for edge in new_edges:
affected_nodes.add(edge.source_node)
affected_nodes.add(edge.target_node)
return new_edges, list(affected_nodes)
61 changes: 35 additions & 26 deletions core/utils/maintenance/graph_data_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,38 @@ async def retrieve_episodes(
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
async with driver.session() as session:
query = """
MATCH (e:EpisodicNode)
RETURN e.content as text, e.timestamp as timestamp, e.reference_timestamp as reference_timestamp
ORDER BY e.timestamp DESC
LIMIT $num_episodes
"""
result = await session.run(query, num_episodes=last_n)
episodes = [
EpisodicNode(
content=record["text"],
transaction_from=datetime.fromtimestamp(
record["timestamp"].to_native().timestamp(), timezone.utc
),
valid_at=(
datetime.fromtimestamp(
record["reference_timestamp"].to_native().timestamp(),
timezone.utc,
)
if record["reference_timestamp"] is not None
else None
),
)
async for record in result
]
return list(reversed(episodes)) # Return in chronological order
query = """
MATCH (e:Episodic)
RETURN e.content as content,
e.created_at as created_at,
e.valid_at as valid_at,
e.uuid as uuid,
e.name as name,
e.source_description as source_description,
e.source as source
ORDER BY e.created_at DESC
LIMIT $num_episodes
"""
result = await driver.execute_query(query, num_episodes=last_n)
episodes = [
EpisodicNode(
content=record["content"],
created_at=datetime.fromtimestamp(
record["created_at"].to_native().timestamp(), timezone.utc
),
valid_at=(
datetime.fromtimestamp(
record["valid_at"].to_native().timestamp(),
timezone.utc,
)
if record["valid_at"] is not None
else None
),
uuid=record["uuid"],
source=record["source"],
name=record["name"],
source_description=record["source_description"],
)
for record in result.records
]
return list(reversed(episodes)) # Return in chronological order
54 changes: 54 additions & 0 deletions podcast_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from core import Graphiti
from core.utils.maintenance.graph_data_operations import clear_data
from dotenv import load_dotenv
import os
import asyncio
import logging
import sys
from transcript_parser import parse_podcast_messages

load_dotenv()

neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687"
neo4j_user = os.environ.get("NEO4J_USER") or "neo4j"
neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password"


def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO

# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Add formatter to console handler
console_handler.setFormatter(formatter)

# Add console handler to logger
logger.addHandler(console_handler)

return logger


async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
messages = parse_podcast_messages()
for i, message in enumerate(messages[3:14]):
await client.add_episode(
name=f"Message {i}",
episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
reference_time=message.actual_timestamp,
source_description="Podcast Transcript",
)


asyncio.run(main())
Loading

0 comments on commit ad552b5

Please sign in to comment.