Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat queries #976

Draft
wants to merge 6 commits into
base: f/switch-to-pg
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions agents-api/agents_api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ class FailedEncodingSentinel:
"""Sentinel object returned when failed to encode payload."""

payload_data: bytes


class QueriesBaseException(AgentsBaseException):
pass


class InvalidSQLQuery(QueriesBaseException):
def __init__(self, query_name: str):
super().__init__(f"invalid query: {query_name}")
15 changes: 0 additions & 15 deletions agents-api/agents_api/models/chat/get_cached_response.py

This file was deleted.

143 changes: 0 additions & 143 deletions agents-api/agents_api/models/chat/prepare_chat_context.py

This file was deleted.

19 changes: 0 additions & 19 deletions agents-api/agents_api/models/chat/set_cached_response.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,4 @@
# ruff: noqa: F401, F403, F405

from .gather_messages import gather_messages
from .get_cached_response import get_cached_response
from .prepare_chat_context import prepare_chat_context
from .set_cached_response import set_cached_response
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import ChatInput, DocReference, History
Expand All @@ -13,8 +12,8 @@
from ..docs.search_docs_by_embedding import search_docs_by_embedding
from ..docs.search_docs_by_text import search_docs_by_text
from ..docs.search_docs_hybrid import search_docs_hybrid
from ..entry.get_history import get_history
from ..session.get_session import get_session
from ..entries.get_history import get_history
from ..sessions.get_session import get_session
from ..utils import (
partialclass,
rewrap_exceptions,
Expand All @@ -25,7 +24,6 @@

@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
Expand Down
169 changes: 169 additions & 0 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from typing import Any, TypeVar
from uuid import UUID

import sqlvalidator
from beartype import beartype

from ...common.protocol.sessions import ChatContext, make_session
from ...exceptions import InvalidSQLQuery
from ..utils import (
pg_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


sql_query = sqlvalidator.parse(
"""SELECT * FROM
(
SELECT jsonb_agg(u) AS users FROM (
SELECT
session_lookup.participant_id,
users.user_id AS id,
users.developer_id,
users.name,
users.about,
users.created_at,
users.updated_at,
users.metadata
FROM session_lookup
INNER JOIN users ON session_lookup.participant_id = users.user_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'user'
) u
) AS users,
(
SELECT jsonb_agg(a) AS agents FROM (
SELECT
session_lookup.participant_id,
agents.agent_id AS id,
agents.developer_id,
agents.canonical_name,
agents.name,
agents.about,
agents.instructions,
agents.model,
agents.created_at,
agents.updated_at,
agents.metadata,
agents.default_settings
FROM session_lookup
INNER JOIN agents ON session_lookup.participant_id = agents.agent_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'agent'
) a
) AS agents,
(
SELECT to_jsonb(s) AS session FROM (
SELECT
sessions.session_id AS id,
sessions.developer_id,
sessions.situation,
sessions.system_template,
sessions.created_at,
sessions.metadata,
sessions.render_templates,
sessions.token_budget,
sessions.context_overflow,
sessions.forward_tool_calls,
sessions.recall_options
FROM sessions
WHERE
developer_id = $1 AND
session_id = $2
LIMIT 1
) s
) AS session,
(
SELECT jsonb_agg(r) AS toolsets FROM (
SELECT
session_lookup.participant_id,
tools.tool_id as id,
tools.developer_id,
tools.agent_id,
tools.task_id,
tools.task_version,
tools.type,
tools.name,
tools.description,
tools.spec,
tools.updated_at,
tools.created_at
FROM session_lookup
INNER JOIN tools ON session_lookup.participant_id = tools.agent_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'agent'
) r
) AS toolsets"""
)
if not sql_query.is_valid():
raise InvalidSQLQuery("prepare_chat_context")


def _transform(d):
toolsets = {}
for tool in d["toolsets"]:
agent_id = tool["agent_id"]
if agent_id in toolsets:
toolsets[agent_id].append(tool)
else:
toolsets[agent_id] = [tool]

return {
**d,
"session": make_session(
agents=[a["id"] for a in d["agents"]],
users=[u["id"] for u in d["users"]],
**d["session"],
),
"toolsets": [
{
"agent_id": agent_id,
"tools": [
{
tool["type"]: tool.pop("spec"),
**tool,
}
for tool in tools
],
}
for agent_id, tools in toolsets.items()
],
}


# TODO: implement this part
# @rewrap_exceptions(
# {
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@wrap_in_class(
ChatContext,
one=True,
transform=_transform,
)
@pg_query
@beartype
async def prepare_chat_context(
*,
developer_id: UUID,
session_id: UUID,
) -> tuple[list[str], list]:
"""
Executes a complex query to retrieve memory context based on session ID.
"""

return (
[sql_query.format()],
[developer_id, session_id],
)