Skip to content

Commit

Permalink
TLK-1771 - Improve agent creation flow
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneLightsOn committed Oct 16, 2024
1 parent 3d23402 commit c460dde
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/backend/tests/unit/routers/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import os

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import ToolName
from backend.crud import agent as agent_crud
from backend.crud import deployment as deployment_crud
from backend.crud import model as model_crud
from backend.database_models.agent import Agent
from backend.database_models.agent_tool_metadata import AgentToolMetadata
from backend.database_models.snapshot import Snapshot
from backend.tests.unit.factories import get_factory

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
and os.environ.get("COHERE_API_KEY") != ""
)

def test_create_agent_missing_name(
session_client: TestClient, session: Session, user
Expand Down Expand Up @@ -98,6 +103,7 @@ def test_create_agent_invalid_deployment(
}


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_create_agent_deployment_not_in_db(
session_client: TestClient, session: Session, user
) -> None:
Expand All @@ -110,18 +116,16 @@ def test_create_agent_deployment_not_in_db(
"deployment": ModelDeploymentName.CoherePlatform,
}
cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
assert cohere_deployment
deployment_crud.delete_deployment(session, cohere_deployment.id)
cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
assert not cohere_deployment
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": user.id}
)
cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
model_command_r_plus = model_crud.get_model_by_name(session, "command-r-plus")
deployment_models = cohere_deployment.models
deployment_models_list = [model.name for model in deployment_models]
assert response.status_code == 200
assert cohere_deployment
assert model_command_r_plus
assert "command-r-plus" in deployment_models_list


def test_create_agent_invalid_tool(
Expand Down

0 comments on commit c460dde

Please sign in to comment.