Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models fixes #1126

Merged
merged 13 commits into from
Aug 28, 2023
8 changes: 4 additions & 4 deletions gui/pages/Content/Agents/AgentCreate.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
updateExecution,
uploadFile,
getAgentDetails, addAgentRun, fetchModels,
getAgentWorkflows
getAgentWorkflows, validateOrAddModels
} from "@/pages/api/DashboardService";
import {
formatBytes,
Expand Down Expand Up @@ -56,7 +56,7 @@ export default function AgentCreate({
const [searchValue, setSearchValue] = useState('');
const [showButton, setShowButton] = useState(false);
const [showPlaceholder, setShowPlaceholder] = useState(true);
const [modelsArray, setModelsArray] = useState([]);
const [modelsArray, setModelsArray] = useState(['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']);

const constraintsArray = [
"If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
Expand All @@ -69,7 +69,7 @@ export default function AgentCreate({
const [goals, setGoals] = useState(['Describe the agent goals here']);
const [instructions, setInstructions] = useState(['']);

const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat']
const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']
const [model, setModel] = useState(models[1]);
const modelRef = useRef(null);
const [modelDropdown, setModelDropdown] = useState(false);
Expand Down Expand Up @@ -494,7 +494,7 @@ export default function AgentCreate({
return true;
}

const handleAddAgent = () => {
const handleAddAgent = async () => {
if (!validateAgentData(true)) {
return;
}
Expand Down
1 change: 0 additions & 1 deletion superagi/agent/agent_iteration_step_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
class AgentIterationStepHandler:
""" Handles iteration workflow steps in the agent workflow."""
def __init__(self, session, llm, agent_id: int, agent_execution_id: int, memory=None):
print(session, llm, agent_execution_id, agent_id, memory)
self.session = session
self.llm = llm
self.agent_execution_id = agent_execution_id
Expand Down
1 change: 0 additions & 1 deletion superagi/agent/agent_tool_step_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def _process_input_instruction(self, agent_config, agent_execution_config, step_
prompt = self._build_tool_input_prompt(step_tool, tool_obj, agent_execution_config)
logger.info("Prompt: ", prompt)
agent_feeds = AgentExecutionFeed.fetch_agent_execution_feeds(self.session, self.agent_execution_id)
print(".........//////////////..........1")
messages = AgentLlmMessageBuilder(self.session, self.llm, self.llm.get_model(), self.agent_id, self.agent_execution_id) \
.build_agent_messages(prompt, agent_feeds, history_enabled=step_tool.history_enabled,
completion_prompt=step_tool.completion_prompt)
Expand Down
6 changes: 0 additions & 6 deletions superagi/jobs/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def execute_next_step(self, agent_execution_id):
return

model_config = AgentConfiguration.get_model_api_key(session, agent_execution.agent_id, agent_config["model"])
print(model_config)
model_api_key = model_config['api_key']
model_llm_source = model_config['provider']
try:
Expand All @@ -72,8 +71,6 @@ def execute_next_step(self, agent_execution_id):
agent_workflow_step = session.query(AgentWorkflowStep).filter(
AgentWorkflowStep.id == agent_execution.current_agent_step_id).first()
try:
print(agent_config["model"])
print(model_api_key)
if agent_workflow_step.action_type == "TOOL":
tool_step_handler = AgentToolStepHandler(session,
llm=get_model(model=agent_config["model"], api_key=model_api_key, organisation_id=organisation.id)
Expand Down Expand Up @@ -107,9 +104,6 @@ def execute_next_step(self, agent_execution_id):

@classmethod
def get_embedding(cls, model_source, model_api_key):
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
print(model_source)
print(model_api_key)
if "OpenAI" in model_source:
return OpenAiEmbedding(api_key=model_api_key)
if "Google" in model_source:
Expand Down
27 changes: 27 additions & 0 deletions superagi/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
from superagi.llms.openai import OpenAi
from superagi.helper.encyption_helper import decrypt_data
import requests, logging

# marketplace_url = "https://app.superagi.com/api"
Expand Down Expand Up @@ -153,6 +155,31 @@ def store_model_details(cls, session, organisation_id, model_name, description,
def fetch_models(cls, session, organisation_id) -> Union[Dict[str, str], List[Dict[str, Union[str, int]]]]:
try:
from superagi.models.models_config import ModelsConfig
from superagi.models.configuration import Configuration

model_provider = session.query(ModelsConfig).filter(ModelsConfig.provider == "OpenAI",
ModelsConfig.org_id == organisation_id).first()
if model_provider is None:
configurations = session.query(Configuration).filter(Configuration.key == 'model_api_key',
Configuration.organisation_id == organisation_id).first()

if configurations is None:
return {"error": "API Key is Missing"}
else:
default_models = {"gpt-3.5-turbo": 4032, "gpt-4": 8092, "gpt-3.5-turbo-16k": 16184}
model_api_key = decrypt_data(configurations.value)

model_details = ModelsConfig.store_api_key(session, organisation_id, "OpenAI", model_api_key)
model_provider_id = model_details.get('model_provider_id')
models = OpenAi(api_key=model_api_key).get_models()

installed_models = [model[0] for model in session.query(Models.model_name).filter(Models.org_id == organisation_id).all()]

for model in models:
if model not in installed_models and model in default_models:
result = cls.store_model_details(session, organisation_id, model, model, '',
model_provider_id, default_models[model], 'Custom', '')

models = session.query(Models.id, Models.model_name, Models.description, ModelsConfig.provider).join(
ModelsConfig, Models.model_provider_id == ModelsConfig.id).filter(
Models.org_id == organisation_id).all()
Expand Down
9 changes: 6 additions & 3 deletions superagi/models/models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,17 @@ def store_api_key(cls, session, organisation_id, model_provider, model_api_key):
ModelsConfig.provider == model_provider)).first()
if existing_entry:
existing_entry.api_key = encrypt_data(model_api_key)
session.commit()
result = {'message': 'The API key was successfully updated'}
else:
new_entry = ModelsConfig(org_id=organisation_id, provider=model_provider,
api_key=encrypt_data(model_api_key))
session.add(new_entry)
session.commit()
session.flush()
result = {'message': 'The API key was successfully stored', 'model_provider_id': new_entry.id}

session.commit()

return {'message': 'The API key was successfully stored'}
return result

@classmethod
def fetch_api_keys(cls, session, organisation_id):
Expand Down
12 changes: 12 additions & 0 deletions tests/unit_tests/models/test_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,16 @@ def test_fetch_model_by_id(mock_session):

# Call the method
model = ModelsConfig.fetch_model_by_id(mock_session, organisation_id, model_provider_id)
assert model == {"provider": "some_provider"}

def test_fetch_model_by_id_marketplace(mock_session):
# Arrange
model_provider_id = 1
# Mock model
mock_model = MagicMock()
mock_model.provider = 'some_provider'
mock_session.query.return_value.filter.return_value.first.return_value = mock_model

# Call the method
model = ModelsConfig.fetch_model_by_id_marketplace(mock_session, model_provider_id)
assert model == {"provider": "some_provider"}