diff --git a/agents-api/agents_api/exceptions.py b/agents-api/agents_api/exceptions.py index 615958a87..f6fcc4741 100644 --- a/agents-api/agents_api/exceptions.py +++ b/agents-api/agents_api/exceptions.py @@ -49,3 +49,12 @@ class FailedEncodingSentinel: """Sentinel object returned when failed to encode payload.""" payload_data: bytes + + +class QueriesBaseException(AgentsBaseException): + pass + + +class InvalidSQLQuery(QueriesBaseException): + def __init__(self, query_name: str): + super().__init__(f"invalid query: {query_name}") diff --git a/agents-api/agents_api/models/chat/get_cached_response.py b/agents-api/agents_api/models/chat/get_cached_response.py deleted file mode 100644 index 368c88567..000000000 --- a/agents-api/agents_api/models/chat/get_cached_response.py +++ /dev/null @@ -1,15 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def get_cached_response(key: str) -> tuple[str, dict]: - query = """ - input[key] <- [[$key]] - ?[key, value] := input[key], *session_cache{key, value} - :limit 1 - """ - - return (query, {"key": key}) diff --git a/agents-api/agents_api/models/chat/prepare_chat_context.py b/agents-api/agents_api/models/chat/prepare_chat_context.py deleted file mode 100644 index f77686d7a..000000000 --- a/agents-api/agents_api/models/chat/prepare_chat_context.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.sessions import ChatContext, make_session -from ..session.prepare_session_data import prepare_session_data -from ..utils import ( - cozo_query, - fix_uuid_if_present, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ChatContext, - one=True, - transform=lambda d: { - **d, - "session": make_session( - agents=[a["id"] for a in d["agents"]], - users=[u["id"] for u in d["users"]], - **d["session"], - ), - "toolsets": [ - { - **ts, - "tools": [ - { - tool["type"]: tool.pop("spec"), - **tool, - } - for tool in map(fix_uuid_if_present, ts["tools"]) - ], - } - for ts in d["toolsets"] - ], - }, -) -@cozo_query -@beartype -def prepare_chat_context( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[list[str], dict]: - """ - Executes a complex query to retrieve memory context based on session ID. - """ - - [*_, session_data_query], sd_vars = prepare_session_data.__wrapped__( - developer_id=developer_id, session_id=session_id - ) - - session_data_fields = ("session", "agents", "users") - - session_data_query += """ - :create _session_data_json { - agents: [Json], - users: [Json], - session: Json, - } - """ - - toolsets_query = """ - input[session_id] <- [[to_uuid($session_id)]] - - tools_by_agent[agent_id, collect(tool)] := - input[session_id], - *session_lookup{ - session_id, - participant_id: agent_id, - participant_type: "agent", - }, - - *tools { agent_id, tool_id, name, type, spec, description, updated_at, created_at }, - tool = { - "id": tool_id, - "name": name, - "type": type, - "spec": spec, - "description": description, - "updated_at": updated_at, - "created_at": created_at, - } - - agent_toolsets[collect(toolset)] := - tools_by_agent[agent_id, tools], - toolset = { - "agent_id": agent_id, - "tools": tools, - } - - ?[toolsets] := - agent_toolsets[toolsets] - - :create _toolsets_json { - toolsets: [Json], - } - """ - - combine_query = f""" - ?[{', '.join(session_data_fields)}, toolsets] := - *_session_data_json {{ {', '.join(session_data_fields)} }}, - *_toolsets_json {{ toolsets }} - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - session_data_query, - toolsets_query, - combine_query, - ] - - return ( - queries, - { - "session_id": str(session_id), - **sd_vars, - }, - ) diff --git a/agents-api/agents_api/models/chat/set_cached_response.py b/agents-api/agents_api/models/chat/set_cached_response.py deleted file mode 100644 index 8625f3f1b..000000000 --- a/agents-api/agents_api/models/chat/set_cached_response.py +++ /dev/null @@ -1,19 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def set_cached_response(key: str, value: dict) -> tuple[str, dict]: - query = """ - ?[key, value] <- [[$key, $value]] - - :insert session_cache { - key => value - } - - :returning - """ - - return (query, {"key": key, "value": value}) diff --git a/agents-api/agents_api/models/chat/__init__.py b/agents-api/agents_api/queries/chat/__init__.py similarity index 92% rename from agents-api/agents_api/models/chat/__init__.py rename to agents-api/agents_api/queries/chat/__init__.py index 428b72572..2c05b4f8b 100644 --- a/agents-api/agents_api/models/chat/__init__.py +++ b/agents-api/agents_api/queries/chat/__init__.py @@ -17,6 +17,4 @@ # ruff: noqa: F401, F403, F405 from .gather_messages import gather_messages -from .get_cached_response import get_cached_response from .prepare_chat_context import prepare_chat_context -from .set_cached_response import set_cached_response diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py similarity index 95% rename from agents-api/agents_api/models/chat/gather_messages.py rename to agents-api/agents_api/queries/chat/gather_messages.py index 28dc6607f..cbf3bf209 100644 --- a/agents-api/agents_api/models/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -3,7 +3,6 @@ from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError from ...autogen.openapi_model import ChatInput, DocReference, History @@ -13,8 +12,8 @@ from ..docs.search_docs_by_embedding import search_docs_by_embedding from ..docs.search_docs_by_text import search_docs_by_text from ..docs.search_docs_hybrid import search_docs_hybrid -from ..entry.get_history import get_history -from ..session.get_session import get_session +from ..entries.get_history import get_history +from ..sessions.get_session import get_session from ..utils import ( partialclass, rewrap_exceptions, @@ -25,7 +24,6 @@ @rewrap_exceptions( { - QueryException: partialclass(HTTPException, status_code=400), ValidationError: partialclass(HTTPException, status_code=400), TypeError: partialclass(HTTPException, status_code=400), } diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py new file mode 100644 index 000000000..1d9bd52fb --- /dev/null +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -0,0 +1,169 @@ +from typing import Any, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...common.protocol.sessions import ChatContext, make_session +from ...exceptions import InvalidSQLQuery +from ..utils import ( + pg_query, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + + +sql_query = sqlvalidator.parse( + """SELECT * FROM +( + SELECT jsonb_agg(u) AS users FROM ( + SELECT + session_lookup.participant_id, + users.user_id AS id, + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata + FROM session_lookup + INNER JOIN users ON session_lookup.participant_id = users.user_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'user' + ) u +) AS users, +( + SELECT jsonb_agg(a) AS agents FROM ( + SELECT + session_lookup.participant_id, + agents.agent_id AS id, + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings + FROM session_lookup + INNER JOIN agents ON session_lookup.participant_id = agents.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) a +) AS agents, +( + SELECT to_jsonb(s) AS session FROM ( + SELECT + sessions.session_id AS id, + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.metadata, + sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options + FROM sessions + WHERE + developer_id = $1 AND + session_id = $2 + LIMIT 1 + ) s +) AS session, +( + SELECT jsonb_agg(r) AS toolsets FROM ( + SELECT + session_lookup.participant_id, + tools.tool_id as id, + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.task_version, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at + FROM session_lookup + INNER JOIN tools ON session_lookup.participant_id = tools.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) r +) AS toolsets""" +) +if not sql_query.is_valid(): + raise InvalidSQLQuery("prepare_chat_context") + + +def _transform(d): + toolsets = {} + for tool in d["toolsets"]: + agent_id = tool["agent_id"] + if agent_id in toolsets: + toolsets[agent_id].append(tool) + else: + toolsets[agent_id] = [tool] + + return { + **d, + "session": make_session( + agents=[a["id"] for a in d["agents"]], + users=[u["id"] for u in d["users"]], + **d["session"], + ), + "toolsets": [ + { + "agent_id": agent_id, + "tools": [ + { + tool["type"]: tool.pop("spec"), + **tool, + } + for tool in tools + ], + } + for agent_id, tools in toolsets.items() + ], + } + + +# TODO: implement this part +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + ChatContext, + one=True, + transform=_transform, +) +@pg_query +@beartype +async def prepare_chat_context( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[list[str], list]: + """ + Executes a complex query to retrieve memory context based on session ID. + """ + + return ( + [sql_query.format()], + [developer_id, session_id], + )