Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models fixes #1126

Merged
merged 13 commits into from
Aug 28, 2023
24 changes: 19 additions & 5 deletions gui/pages/Content/Agents/AgentCreate.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
updateExecution,
uploadFile,
getAgentDetails, addAgentRun, fetchModels,
getAgentWorkflows
getAgentWorkflows, validateOrAddModels
} from "@/pages/api/DashboardService";
import {
formatBytes,
Expand Down Expand Up @@ -56,7 +56,7 @@ export default function AgentCreate({
const [searchValue, setSearchValue] = useState('');
const [showButton, setShowButton] = useState(false);
const [showPlaceholder, setShowPlaceholder] = useState(true);
const [modelsArray, setModelsArray] = useState([]);
const [modelsArray, setModelsArray] = useState(['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']);

const constraintsArray = [
"If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
Expand All @@ -69,7 +69,7 @@ export default function AgentCreate({
const [goals, setGoals] = useState(['Describe the agent goals here']);
const [instructions, setInstructions] = useState(['']);

const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat']
const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
const [model, setModel] = useState(models[1]);
const modelRef = useRef(null);
const [modelDropdown, setModelDropdown] = useState(false);
Expand Down Expand Up @@ -155,7 +155,7 @@ export default function AgentCreate({
.then((response) => {
const models = response.data.map(model => model.name) || [];
const selected_model = localStorage.getItem("agent_model_" + String(internalId)) || '';
setModelsArray(models);
setModelsArray(prevModels => Array.from(new Set([...prevModels, ...models])));
if (models.length > 0 && !selected_model) {
setLocalStorageValue("agent_model_" + String(internalId), models[0], setModel);
} else {
Expand Down Expand Up @@ -494,7 +494,21 @@ export default function AgentCreate({
return true;
}

const handleAddAgent = () => {
const validateModel = async () => {
const response = await validateOrAddModels(model)
if (response.data.error) {
toast.error(response.data.error, {autoClose: 1800});
return false;
}
return true;
}

const handleAddAgent = async () => {
if(env === 'DEV') {
const bool = await validateModel()
if(!bool) return;
}

if (!validateAgentData(true)) {
return;
}
Expand Down
5 changes: 5 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,8 @@ export const fetchMarketPlaceModel = () => {
return api.get(`/models_controller/get/list`)
}

export const validateOrAddModels = (model) => {
return api.get(`/models_controller/validate_or_add_gpt_models`, {
params: { model }
});
}
9 changes: 9 additions & 0 deletions superagi/controllers/models_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@
raise HTTPException(status_code=500, detail="Internal Server Error")


@router.get("/validate_or_add_gpt_models", status_code=200)
async def validate_or_add_gpt_models(model: str = None, organisation=Depends(get_user_organisation)):
try:
return Models.validate_model_in_db(db.session, organisation.id, model)
except Exception as e:
logging.error(f"Error Validating or Adding GPT Models: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")

Check warning on line 109 in superagi/controllers/models_controller.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/models_controller.py#L105-L109

Added lines #L105 - L109 were not covered by tests


@router.get("/get/list", status_code=200)
def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)):
"""
Expand Down
46 changes: 46 additions & 0 deletions superagi/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
from superagi.helper.encyption_helper import encrypt_data, decrypt_data
import requests, logging

# marketplace_url = "https://app.superagi.com/api"
Expand Down Expand Up @@ -201,3 +202,48 @@
except Exception as e:
logging.error(f"Unexpected Error Occured: {e}")
return {"error": "Unexpected Error Occured"}

@classmethod
def validate_model_in_db(cls, session, organisation_id, model):
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the imports to top of the file.

from superagi.models.models_config import ModelsConfig
from superagi.models.configuration import Configuration

Check warning on line 210 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L208-L210

Added lines #L208 - L210 were not covered by tests

models = {"gpt-3.5-turbo-0301": 4032, "gpt-4-0314": 8092, "gpt-3.5-turbo": 4032,

Check warning on line 212 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L212

Added line #L212 was not covered by tests
"gpt-4": 8092, "gpt-3.5-turbo-16k": 16184, "gpt-4-32k": 32768}

model_config = session.query(Models).filter(Models.model_name == model,

Check warning on line 215 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L215

Added line #L215 was not covered by tests
Models.org_id == organisation_id).first()
if model_config is None:
model_provider = session.query(ModelsConfig).filter(ModelsConfig.provider == "OpenAI",

Check warning on line 218 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L218

Added line #L218 was not covered by tests
ModelsConfig.org_id == organisation_id).first()

if model_provider is None:
configurations = session.query(Configuration).filter(Configuration.key == 'model_api_key',

Check warning on line 222 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L222

Added line #L222 was not covered by tests
Configuration.organisation_id == organisation_id).first()
model_api_key = decrypt_data(configurations.value)

Check warning on line 224 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L224

Added line #L224 was not covered by tests

if configurations is None:
return {"error": "Model not found and the API Key is missing"}

Check warning on line 227 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L227

Added line #L227 was not covered by tests

model_details = ModelsConfig.store_api_key(session, organisation_id, "OpenAI", model_api_key)

Check warning on line 229 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L229

Added line #L229 was not covered by tests

# Get 'model_provider_id'
model_provider_id = model_details.get('model_provider_id')

Check warning on line 232 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L232

Added line #L232 was not covered by tests

result = cls.store_model_details(session, organisation_id, model, model, '',

Check warning on line 234 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L234

Added line #L234 was not covered by tests
model_provider_id, models[model], 'Custom', '')
if result is not None:
return {"success": "Model was not Installed, so I have dont it for you"}

Check warning on line 237 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L237

Added line #L237 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle using HTTP code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use HTTP Exception


else:
result = cls.store_model_details(session, organisation_id, model, model, '',

Check warning on line 240 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L240

Added line #L240 was not covered by tests
model_provider.id, models[model], 'Custom', '')
if result is not None:
return {"success": "Model was not Installed, so I have dont it for you"}

Check warning on line 243 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L243

Added line #L243 was not covered by tests

else:
return {"success": "Model is found"}

Check warning on line 246 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L246

Added line #L246 was not covered by tests

except Exception as e:
logging.error(f"Unexpected Error occurred while Validating GPT Models: {e}")

Check warning on line 249 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L248-L249

Added lines #L248 - L249 were not covered by tests
11 changes: 8 additions & 3 deletions superagi/models/models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,17 @@
ModelsConfig.provider == model_provider)).first()
if existing_entry:
existing_entry.api_key = encrypt_data(model_api_key)
session.commit()
result = {'message': 'The API key was successfully updated'}
else:
new_entry = ModelsConfig(org_id=organisation_id, provider=model_provider,
api_key=encrypt_data(model_api_key))
session.add(new_entry)
session.commit()
session.flush()
result = {'message': 'The API key was successfully stored', 'model_provider_id': new_entry.id}

Check warning on line 87 in superagi/models/models_config.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models_config.py#L85-L87

Added lines #L85 - L87 were not covered by tests

session.commit()

return {'message': 'The API key was successfully stored'}
return result

@classmethod
def fetch_api_keys(cls, session, organisation_id):
Expand All @@ -107,6 +110,8 @@
if api_key_data is None:
return []
else:
print("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove print

print(decrypt_data(api_key_data.api_key))

Check warning on line 114 in superagi/models/models_config.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models_config.py#L113-L114

Added lines #L113 - L114 were not covered by tests
api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider,
'api_key': decrypt_data(api_key_data.api_key)}]
return api_key
Expand Down
12 changes: 12 additions & 0 deletions tests/unit_tests/models/test_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,16 @@ def test_fetch_model_by_id(mock_session):

# Call the method
model = ModelsConfig.fetch_model_by_id(mock_session, organisation_id, model_provider_id)
assert model == {"provider": "some_provider"}

def test_fetch_model_by_id_marketplace(mock_session):
# Arrange
model_provider_id = 1
# Mock model
mock_model = MagicMock()
mock_model.provider = 'some_provider'
mock_session.query.return_value.filter.return_value.first.return_value = mock_model

# Call the method
model = ModelsConfig.fetch_model_by_id_marketplace(mock_session, model_provider_id)
assert model == {"provider": "some_provider"}