Skip to content

Commit

Permalink
fix(backend): Add migrations to fix credentials inputs with invalid p…
Browse files Browse the repository at this point in the history
…rovider "llm"(#8674)

In #8524, the "llm" credentials provider was replaced. There are still entries with 	"provider": "llm"	 in the system though, and those break if not migrated.

- SQL migration to fix the obvious ones where we know the provider from `credentials.id`
- Non-SQL migration to fix the rest
  • Loading branch information
Pwuts authored Nov 15, 2024
1 parent 0551bec commit 4db8e74
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
82 changes: 82 additions & 0 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timezone
from typing import Any, Literal, Type

import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
Expand Down Expand Up @@ -528,3 +529,84 @@ async def __create_graph(tx, graph: Graph, user_id: str):
for link in graph.links
]
)


# ------------------------ UTILITIES ------------------------ #


async def fix_llm_provider_credentials():
"""Fix node credentials with provider `llm`"""
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)

from .redis import get_redis
from .user import get_user_integrations

store = SupabaseIntegrationCredentialsStore(get_redis())

broken_nodes = await prisma.get_client().query_raw(
"""
SELECT user.id user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentGraph" graph
LEFT JOIN platform."AgentNode" node
ON node."agentGraphId" = graph.id
LEFT JOIN platform."User" user
ON graph."userId" = user.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY user_id;
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")

user_id: str = ""
user_integrations = None
for node in broken_nodes:
if node["user_id"] != user_id:
# Save queries by only fetching once per user
user_id = node["user_id"]
user_integrations = await get_user_integrations(user_id)
elif not user_integrations:
raise RuntimeError(f"Impossible state while processing node {node}")

node_id: str = node["node_id"]
node_preset_input: dict = json.loads(node["node_preset_input"])
credentials_meta: dict = node_preset_input["credentials"]

credentials = next(
(
c
for c in user_integrations.credentials
if c.id == credentials_meta["id"]
),
None,
)
if not credentials:
continue
if credentials.type != "api_key":
logger.warning(
f"User {user_id} credentials {credentials.id} with provider 'llm' "
f"has invalid type '{credentials.type}'"
)
continue

api_key = credentials.api_key.get_secret_value()
if api_key.startswith("sk-ant-api03-"):
credentials.provider = credentials_meta["provider"] = "anthropic"
elif api_key.startswith("sk-"):
credentials.provider = credentials_meta["provider"] = "openai"
elif api_key.startswith("gsk_"):
credentials.provider = credentials_meta["provider"] = "groq"
else:
logger.warning(
f"Could not identify provider from key prefix {api_key[:13]}*****"
)
continue

store.update_creds(user_id, credentials)
await AgentNode.prisma().update(
where={"id": node_id},
data={"constantInput": json.dumps(node_preset_input)},
)
2 changes: 2 additions & 0 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.util.service
Expand All @@ -23,6 +24,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
yield
await backend.data.db.disconnect()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Correct credentials.provider field on all nodes with 'llm' provider credentials
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{credentials,provider}',
CASE
WHEN "constantInput"::jsonb->'credentials'->>'id' = '53c25cb8-e3ee-465c-a4d1-e75a4c899c2a' THEN '"openai"'::jsonb
WHEN "constantInput"::jsonb->'credentials'->>'id' = '24e5d942-d9e3-4798-8151-90143ee55629' THEN '"anthropic"'::jsonb
WHEN "constantInput"::jsonb->'credentials'->>'id' = '4ec22295-8f97-4dd1-b42b-2c6957a02545' THEN '"groq"'::jsonb
ELSE ("constantInput"::jsonb->'credentials'->>'provider')::jsonb
END
)::text
WHERE "constantInput"::jsonb->'credentials'->>'provider' = 'llm';

0 comments on commit 4db8e74

Please sign in to comment.