Models fixes (#1126)
* Models Frontend Changes

* Models Frontend Changes

* Models Frontend Changes

* Backend Compatibility for New/Existing users on local

* DEV api key requirements

* removing print statements

* removing print statements

* removing print statements

* removing print statements

* backend compatibility

* backend compatibility

* backend compatibility
jedan2506 committed Aug 28, 2023
1 parent 50dbec2 commit 737268b
Showing 7 changed files with 49 additions and 15 deletions.
8 changes: 4 additions & 4 deletions gui/pages/Content/Agents/AgentCreate.js
@@ -11,7 +11,7 @@ import {
updateExecution,
uploadFile,
getAgentDetails, addAgentRun, fetchModels,
-getAgentWorkflows
+getAgentWorkflows, validateOrAddModels
} from "@/pages/api/DashboardService";
import {
formatBytes,
@@ -56,7 +56,7 @@ export default function AgentCreate({
const [searchValue, setSearchValue] = useState('');
const [showButton, setShowButton] = useState(false);
const [showPlaceholder, setShowPlaceholder] = useState(true);
-const [modelsArray, setModelsArray] = useState([]);
+const [modelsArray, setModelsArray] = useState(['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']);

const constraintsArray = [
"If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
@@ -69,7 +69,7 @@ export default function AgentCreate({
const [goals, setGoals] = useState(['Describe the agent goals here']);
const [instructions, setInstructions] = useState(['']);

-const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat']
+const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']
const [model, setModel] = useState(models[1]);
const modelRef = useRef(null);
const [modelDropdown, setModelDropdown] = useState(false);
@@ -494,7 +494,7 @@ export default function AgentCreate({
return true;
}

-const handleAddAgent = () => {
+const handleAddAgent = async () => {
if (!validateAgentData(true)) {
return;
}
1 change: 0 additions & 1 deletion superagi/agent/agent_iteration_step_handler.py
@@ -34,7 +34,6 @@
class AgentIterationStepHandler:
""" Handles iteration workflow steps in the agent workflow."""
def __init__(self, session, llm, agent_id: int, agent_execution_id: int, memory=None):
-print(session, llm, agent_execution_id, agent_id, memory)
self.session = session
self.llm = llm
self.agent_execution_id = agent_execution_id
1 change: 0 additions & 1 deletion superagi/agent/agent_tool_step_handler.py
@@ -99,7 +99,6 @@ def _process_input_instruction(self, agent_config, agent_execution_config, step_
prompt = self._build_tool_input_prompt(step_tool, tool_obj, agent_execution_config)
logger.info("Prompt: ", prompt)
agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id)
print(".........//////////////..........1")
messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \
.build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled,
completion_prompt=step_tool.completion_prompt)
6 changes: 0 additions & 6 deletions superagi/jobs/agent_executor.py
@@ -56,7 +56,6 @@ def execute_next_step(self, agent_execution_id):
return

model_config = AgentConfiguration.get_model_api_key(session, agent_execution.agent_id, agent_config["model"])
-print(model_config)
model_api_key = model_config['api_key']
model_llm_source = model_config['provider']
try:
@@ -72,8 +71,6 @@
agent_workflow_step = session.query(AgentWorkflowStep).filter(
AgentWorkflowStep.id == agent_execution.current_agent_step_id).first()
try:
-print(agent_config["model"])
-print(model_api_key)
if agent_workflow_step.action_type == "TOOL":
tool_step_handler = AgentToolStepHandler(session,
llm=get_model(model=agent_config["model"], api_key=model_api_key, organisation_id=organisation.id)
@@ -107,9 +104,6 @@

@classmethod
def get_embedding(cls, model_source, model_api_key):
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
print(model_source)
print(model_api_key)
if "OpenAI" in model_source:
return OpenAiEmbedding(api_key=model_api_key)
if "Google" in model_source:
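The deletions in this file are leftover debug print() calls around the model lookup and embedding selection. If that visibility is ever needed again, a leveled log call that redacts the key is a safer substitute; the snippet below is only a suggestion sketch, assuming the shared logger from superagi.lib.logger (the redaction helper and function name are hypothetical, not part of this commit):

from superagi.lib.logger import logger

def _redacted(secret: str, visible: int = 4) -> str:
    # Keep only the last few characters so API keys never reach the logs verbatim.
    if not secret:
        return ""
    return "*" * max(len(secret) - visible, 0) + secret[-visible:]

def log_model_selection(model_name: str, provider: str, api_key: str) -> None:
    # Comparable visibility to the removed print statements, but leveled and redacted.
    logger.debug(f"Executing step with model={model_name} provider={provider} api_key={_redacted(api_key)}")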
27 changes: 27 additions & 0 deletions superagi/models/models.py
@@ -2,6 +2,8 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
+from superagi.llms.openai import OpenAi
+from superagi.helper.encyption_helper import decrypt_data
import requests, logging

# marketplace_url = "https://app.superagi.com/api"
@@ -153,6 +155,31 @@ def store_model_details(cls, session, organisation_id, model_name, description,
def fetch_models(cls, session, organisation_id) -> Union[Dict[str, str], List[Dict[str, Union[str, int]]]]:
try:
from superagi.models.models_config import ModelsConfig
+from superagi.models.configuration import Configuration
+
+model_provider = session.query(ModelsConfig).filter(ModelsConfig.provider == "OpenAI",
+ModelsConfig.org_id == organisation_id).first()
+if model_provider is None:
+configurations = session.query(Configuration).filter(Configuration.key == 'model_api_key',
+Configuration.organisation_id == organisation_id).first()
+
+if configurations is None:
+return {"error": "API Key is Missing"}
+else:
+default_models = {"gpt-3.5-turbo": 4032, "gpt-4": 8092, "gpt-3.5-turbo-16k": 16184}
+model_api_key = decrypt_data(configurations.value)
+
+model_details = ModelsConfig.store_api_key(session, organisation_id, "OpenAI", model_api_key)
+model_provider_id = model_details.get('model_provider_id')
+models = OpenAi(api_key=model_api_key).get_models()
+
+installed_models = [model[0] for model in session.query(Models.model_name).filter(Models.org_id == organisation_id).all()]
+
+for model in models:
+if model not in installed_models and model in default_models:
+result = cls.store_model_details(session, organisation_id, model, model, '',
+model_provider_id, default_models[model], 'Custom', '')
+
models = session.query(Models.id, Models.model_name, Models.description, ModelsConfig.provider).join(
ModelsConfig, Models.model_provider_id == ModelsConfig.id).filter(
Models.org_id == organisation_id).all()
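The block added above is the local backward-compatibility path: if an organisation has no OpenAI row in ModelsConfig but still holds a legacy 'model_api_key' entry in Configuration, the key is decrypted, re-registered through ModelsConfig.store_api_key, and the default OpenAI models are seeded for that organisation if not already installed. Note that fetch_models can now return an error dict instead of a list. A minimal sketch of how a caller might handle that, assuming a FastAPI-style handler (the function name and error mapping below are illustrative, not the project's actual endpoint):

from fastapi import HTTPException

from superagi.models.models import Models

def list_models_for_org(session, organisation_id: int):
    # fetch_models returns either a list of model rows or {"error": "API Key is Missing"}.
    result = Models.fetch_models(session, organisation_id)
    if isinstance(result, dict) and "error" in result:
        # Surface the missing-key case to the client instead of returning it as data.
        raise HTTPException(status_code=400, detail=result["error"])
    return result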
9 changes: 6 additions & 3 deletions superagi/models/models_config.py
@@ -76,14 +76,17 @@ def store_api_key(cls, session, organisation_id, model_provider, model_api_key):
ModelsConfig.provider == model_provider)).first()
if existing_entry:
existing_entry.api_key = encrypt_data(model_api_key)
-session.commit()
+result = {'message': 'The API key was successfully updated'}
else:
new_entry = ModelsConfig(org_id=organisation_id, provider=model_provider,
api_key=encrypt_data(model_api_key))
session.add(new_entry)
-session.commit()
+session.flush()
+result = {'message': 'The API key was successfully stored', 'model_provider_id': new_entry.id}

+session.commit()

-return {'message': 'The API key was successfully stored'}
+return result

@classmethod
def fetch_api_keys(cls, session, organisation_id):
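The reshuffle above exists so the new provider row's autogenerated id can be returned: session.flush() emits the INSERT (populating new_entry.id) without ending the transaction, and the single session.commit() afterwards finalises both the update and insert branches. A stripped-down illustration of the same SQLAlchemy pattern (the model, table, and session here are stand-ins, not SuperAGI code):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Provider(Base):
    __tablename__ = "providers"
    id = Column(Integer, primary_key=True)
    name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    row = Provider(name="OpenAI")
    session.add(row)
    session.flush()              # emits the INSERT, so the autoincrement id is populated
    assert row.id is not None    # the id is usable here, before the transaction ends
    session.commit()             # one commit finalises everything done in the session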
12 changes: 12 additions & 0 deletions tests/unit_tests/models/test_models_config.py
@@ -111,4 +111,16 @@ def test_fetch_model_by_id(mock_session):

# Call the method
model = ModelsConfig.fetch_model_by_id(mock_session, organisation_id, model_provider_id)
assert model == {"provider": "some_provider"}

+def test_fetch_model_by_id_marketplace(mock_session):
+# Arrange
+model_provider_id = 1
+# Mock model
+mock_model = MagicMock()
+mock_model.provider = 'some_provider'
+mock_session.query.return_value.filter.return_value.first.return_value = mock_model
+
+# Call the method
+model = ModelsConfig.fetch_model_by_id_marketplace(mock_session, model_provider_id)
+assert model == {"provider": "some_provider"}
