From 737268b557e8658992d4192324987437f15ff363 Mon Sep 17 00:00:00 2001 From: Kalki <97698934+jedan2506@users.noreply.github.com> Date: Mon, 28 Aug 2023 18:58:18 +0530 Subject: [PATCH] Models fixes (#1126) * Models Frontend Changes * Models Frontend Changes * Models Frontend Changes * Backend Compatibility for New/Existing users on local * DEV api key requirements * removing print statements * removing print statements * removing print statements * removing print statements * backend compatibility * backend compatibility * backend compatibility --- gui/pages/Content/Agents/AgentCreate.js | 8 +++--- .../agent/agent_iteration_step_handler.py | 1 - superagi/agent/agent_tool_step_handler.py | 1 - superagi/jobs/agent_executor.py | 6 ----- superagi/models/models.py | 27 +++++++++++++++++++ superagi/models/models_config.py | 9 ++++--- tests/unit_tests/models/test_models_config.py | 12 +++++++++ 7 files changed, 49 insertions(+), 15 deletions(-) diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 95f731786..4d2849042 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -11,7 +11,7 @@ import { updateExecution, uploadFile, getAgentDetails, addAgentRun, fetchModels, - getAgentWorkflows + getAgentWorkflows, validateOrAddModels } from "@/pages/api/DashboardService"; import { formatBytes, @@ -56,7 +56,7 @@ export default function AgentCreate({ const [searchValue, setSearchValue] = useState(''); const [showButton, setShowButton] = useState(false); const [showPlaceholder, setShowPlaceholder] = useState(true); - const [modelsArray, setModelsArray] = useState([]); + const [modelsArray, setModelsArray] = useState(['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']); const constraintsArray = [ "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.", @@ -69,7 +69,7 @@ export default function AgentCreate({ const [goals, setGoals] = useState(['Describe the agent goals here']); const [instructions, setInstructions] = useState(['']); - const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat'] + const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k'] const [model, setModel] = useState(models[1]); const modelRef = useRef(null); const [modelDropdown, setModelDropdown] = useState(false); @@ -494,7 +494,7 @@ export default function AgentCreate({ return true; } - const handleAddAgent = () => { + const handleAddAgent = async () => { if (!validateAgentData(true)) { return; } diff --git a/superagi/agent/agent_iteration_step_handler.py b/superagi/agent/agent_iteration_step_handler.py index 94e50c3e9..543a285e6 100644 --- a/superagi/agent/agent_iteration_step_handler.py +++ b/superagi/agent/agent_iteration_step_handler.py @@ -34,7 +34,6 @@ class AgentIterationStepHandler: """ Handles iteration workflow steps in the agent workflow.""" def __init__(self, session, llm, agent_id: int, agent_execution_id: int, memory=None): - print(session, llm, agent_execution_id, agent_id, memory) self.session = session self.llm = llm self.agent_execution_id = agent_execution_id diff --git a/superagi/agent/agent_tool_step_handler.py b/superagi/agent/agent_tool_step_handler.py index 52951a00d..7aeb0d59b 100644 --- a/superagi/agent/agent_tool_step_handler.py +++ b/superagi/agent/agent_tool_step_handler.py @@ -99,7 +99,6 @@ def _process_input_instruction(self, agent_config, agent_execution_config, step_ prompt = self._build_tool_input_prompt(step_tool, tool_obj, agent_execution_config) logger.info("Prompt: ", prompt) agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id) - print(".........//////////////..........1") messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \ .build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled, completion_prompt=step_tool.completion_prompt) diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index faf1ccc48..8ff0977b2 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -56,7 +56,6 @@ def execute_next_step(self, agent_execution_id): return model_config = AgentConfiguration.get_model_api_key(session, agent_execution.agent_id, agent_config["model"]) - print(model_config) model_api_key = model_config['api_key'] model_llm_source = model_config['provider'] try: @@ -72,8 +71,6 @@ def execute_next_step(self, agent_execution_id): agent_workflow_step = session.query(AgentWorkflowStep).filter( AgentWorkflowStep.id == agent_execution.current_agent_step_id).first() try: - print(agent_config["model"]) - print(model_api_key) if agent_workflow_step.action_type == "TOOL": tool_step_handler = AgentToolStepHandler(session, llm=get_model(model=agent_config["model"], api_key=model_api_key, organisation_id=organisation.id) @@ -107,9 +104,6 @@ def execute_next_step(self, agent_execution_id): @classmethod def get_embedding(cls, model_source, model_api_key): - print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") - print(model_source) - print(model_api_key) if "OpenAI" in model_source: return OpenAiEmbedding(api_key=model_api_key) if "Google" in model_source: diff --git a/superagi/models/models.py b/superagi/models/models.py index d024c842f..35c0e9303 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -2,6 +2,8 @@ from sqlalchemy.sql import func from typing import List, Dict, Union from superagi.models.base_model import DBBaseModel +from superagi.llms.openai import OpenAi +from superagi.helper.encyption_helper import decrypt_data import requests, logging # marketplace_url = "https://app.superagi.com/api" @@ -153,6 +155,31 @@ def store_model_details(cls, session, organisation_id, model_name, description, def fetch_models(cls, session, organisation_id) -> Union[Dict[str, str], List[Dict[str, Union[str, int]]]]: try: from superagi.models.models_config import ModelsConfig + from superagi.models.configuration import Configuration + + model_provider = session.query(ModelsConfig).filter(ModelsConfig.provider == "OpenAI", + ModelsConfig.org_id == organisation_id).first() + if model_provider is None: + configurations = session.query(Configuration).filter(Configuration.key == 'model_api_key', + Configuration.organisation_id == organisation_id).first() + + if configurations is None: + return {"error": "API Key is Missing"} + else: + default_models = {"gpt-3.5-turbo": 4032, "gpt-4": 8092, "gpt-3.5-turbo-16k": 16184} + model_api_key = decrypt_data(configurations.value) + + model_details = ModelsConfig.store_api_key(session, organisation_id, "OpenAI", model_api_key) + model_provider_id = model_details.get('model_provider_id') + models = OpenAi(api_key=model_api_key).get_models() + + installed_models = [model[0] for model in session.query(Models.model_name).filter(Models.org_id == organisation_id).all()] + + for model in models: + if model not in installed_models and model in default_models: + result = cls.store_model_details(session, organisation_id, model, model, '', + model_provider_id, default_models[model], 'Custom', '') + models = session.query(Models.id, Models.model_name, Models.description, ModelsConfig.provider).join( ModelsConfig, Models.model_provider_id == ModelsConfig.id).filter( Models.org_id == organisation_id).all() diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py index a8d2a2ba6..fc4bad9e3 100644 --- a/superagi/models/models_config.py +++ b/superagi/models/models_config.py @@ -76,14 +76,17 @@ def store_api_key(cls, session, organisation_id, model_provider, model_api_key): ModelsConfig.provider == model_provider)).first() if existing_entry: existing_entry.api_key = encrypt_data(model_api_key) + session.commit() + result = {'message': 'The API key was successfully updated'} else: new_entry = ModelsConfig(org_id=organisation_id, provider=model_provider, api_key=encrypt_data(model_api_key)) session.add(new_entry) + session.commit() + session.flush() + result = {'message': 'The API key was successfully stored', 'model_provider_id': new_entry.id} - session.commit() - - return {'message': 'The API key was successfully stored'} + return result @classmethod def fetch_api_keys(cls, session, organisation_id): diff --git a/tests/unit_tests/models/test_models_config.py b/tests/unit_tests/models/test_models_config.py index 56bfd89c0..5c1731d3a 100644 --- a/tests/unit_tests/models/test_models_config.py +++ b/tests/unit_tests/models/test_models_config.py @@ -111,4 +111,16 @@ def test_fetch_model_by_id(mock_session): # Call the method model = ModelsConfig.fetch_model_by_id(mock_session, organisation_id, model_provider_id) + assert model == {"provider": "some_provider"} + +def test_fetch_model_by_id_marketplace(mock_session): + # Arrange + model_provider_id = 1 + # Mock model + mock_model = MagicMock() + mock_model.provider = 'some_provider' + mock_session.query.return_value.filter.return_value.first.return_value = mock_model + + # Call the method + model = ModelsConfig.fetch_model_by_id_marketplace(mock_session, model_provider_id) assert model == {"provider": "some_provider"} \ No newline at end of file