Skip to content

Commit

Permalink
Llm models fix (#829)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihiragarwal24 committed Jul 20, 2023
1 parent 8e2396c commit b58067f
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 11 deletions.
37 changes: 26 additions & 11 deletions gui/pages/Content/Agents/AgentCreate.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
createAgent,
fetchAgentTemplateConfigLocal,
getOrganisationConfig,
getLlmModels,
updateExecution,
uploadFile
} from "@/pages/api/DashboardService";
Expand Down Expand Up @@ -46,8 +47,8 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
const [goals, setGoals] = useState(['Describe the agent goals here']);
const [instructions, setInstructions] = useState(['']);

const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001']
const [model, setModel] = useState(models[1]);
const [modelsArray, setModelsArray] = useState([]);
const [model, setModel] = useState('');
const modelRef = useRef(null);
const [modelDropdown, setModelDropdown] = useState(false);

Expand Down Expand Up @@ -88,11 +89,6 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
const [createModal, setCreateModal] = useState(false);

const [scheduleData, setScheduleData] = useState(null);
const [col6ScrollTop, setCol6ScrollTop] = useState(0);

const handleCol3Scroll = (event) => {
setCol6ScrollTop(event.target.scrollTop);
};

useEffect(() => {
getOrganisationConfig(organisationId, "model_api_key")
Expand Down Expand Up @@ -125,6 +121,21 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
}, [toolNames]);

useEffect(() => {
getLlmModels()
.then((response) => {
const models = response.data || [];
const selected_model = localStorage.getItem("agent_model_" + String(internalId)) || '';
setModelsArray(models);
if(models.length > 0 && !selected_model) {
setLocalStorageValue("agent_model_" + String(internalId), models[0], setModel);
} else {
setModel(selected_model);
}
})
.catch((error) => {
console.error('Error fetching models:', error);
});

if (template !== null) {
setLocalStorageValue("agent_name_" + String(internalId), template.name, setAgentName);
setLocalStorageValue("agent_description_" + String(internalId), template.description, setAgentDescription);
Expand Down Expand Up @@ -250,8 +261,8 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
};

const handleModelSelect = (index) => {
setLocalStorageValue("agent_model_" + String(internalId), models[index], setModel);
if (models[index] === "google-palm-bison-001") {
setLocalStorageValue("agent_model_" + String(internalId), modelsArray[index], setModel);
if (modelsArray[index] === "google-palm-bison-001") {
setAgentType("Fixed Task Queue")
}
setModelDropdown(false);
Expand Down Expand Up @@ -381,6 +392,10 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
toast.error("Add atleast one tool", {autoClose: 1800});
return
}
if(!modelsArray.includes(model)) {
toast.error("Your key does not have access to the selected model", {autoClose: 1800});
return
}

setCreateClickable(false);

Expand Down Expand Up @@ -650,7 +665,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen

return (<>
<div className="row" style={{overflowY: 'scroll', height: 'calc(100vh - 92px)'}}>
<div className="col-3" onScroll={handleCol3Scroll}></div>
<div className="col-3"></div>
<div className="col-6" style={{padding: '25px 20px'}}>
<div>
<div className={styles.page_title}>Create new agent</div>
Expand Down Expand Up @@ -721,7 +736,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
</div>
<div>
{modelDropdown && <div className="custom_select_options" ref={modelRef} style={{width: '100%'}}>
{models.map((model, index) => (
{modelsArray?.map((model, index) => (
<div key={index} className="custom_select_option" onClick={() => handleModelSelect(index)}
style={{padding: '12px 14px', maxWidth: '100%'}}>
{model}
Expand Down
4 changes: 4 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,8 @@ export const getActiveRuns = () => {

export const getToolsUsage = () => {
return api.get(`analytics/tools/used`);
}

// Fetch the LLM model ids that the current organisation's API key can access.
export const getLlmModels = () => {
  return api.get(`organisations/llm_models`);
}
34 changes: 34 additions & 0 deletions superagi/controllers/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
from fastapi_sqlalchemy import db
from pydantic import BaseModel

from superagi.helper.auth import get_user_organisation
from superagi.helper.auth import check_auth
from superagi.helper.encyption_helper import decrypt_data
from superagi.helper.tool_helper import register_toolkits
from superagi.llms.google_palm import GooglePalm
from superagi.llms.openai import OpenAi
from superagi.models.configuration import Configuration
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.user import User
Expand Down Expand Up @@ -35,6 +40,7 @@ class OrganisationIn(BaseModel):
class Config:
orm_mode = True


# CRUD Operations
@router.post("/add", response_model=OrganisationOut, status_code=201)
def create_organisation(organisation: OrganisationIn,
Expand Down Expand Up @@ -141,3 +147,31 @@ def get_organisations_by_user(user_id: int):
organisation = Organisation.find_or_create_organisation(db.session, user)
Project.find_or_create_default_project(db.session, organisation.id)
return organisation


@router.get("/llm_models")
def get_llm_models(organisation=Depends(get_user_organisation)):
    """
    List the LLM model ids available to an organisation.

    Looks up the organisation's configured model source and API key,
    decrypts the key, and asks the matching provider for its models.

    Args:
        organisation: Organisation resolved from the authenticated user.

    Raises:
        HTTPException: 400 when the API key or model source is not configured.

    Returns:
        list: Model identifiers, or an empty list for an unknown source.
    """
    def _config_row(key):
        # Fetch one configuration row scoped to this organisation.
        return db.session.query(Configuration).filter(
            Configuration.organisation_id == organisation.id,
            Configuration.key == key,
        ).first()

    api_key_row = _config_row("model_api_key")
    source_row = _config_row("model_source")

    if api_key_row is None or source_row is None:
        raise HTTPException(status_code=400, detail="Organisation not found")

    api_key = decrypt_data(api_key_row.value)

    if source_row.value == "OpenAi":
        return OpenAi(api_key=api_key).get_models()
    if source_row.value == "Google Palm":
        return GooglePalm(api_key=api_key).get_models()
    return []
4 changes: 4 additions & 0 deletions superagi/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def get_api_key(self):
def get_model(self):
pass

@abstractmethod
def get_models(self):
    """Return the list of model identifiers this provider makes available."""
    pass

@abstractmethod
def verify_access_key(self):
    """Check whether the configured API key is valid for this provider."""
    pass
14 changes: 14 additions & 0 deletions superagi/llms/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,17 @@ def verify_access_key(self):
except Exception as exception:
logger.info("Google palm Exception:", exception)
return False

def get_models(self):
    """
    Get the models supported by the Google Palm integration.

    Returns:
        list: Identifiers of the supported chat models.
    """
    # The supported list is a hard-coded constant; building it cannot raise,
    # so the original try/except (which swallowed an impossible exception)
    # has been removed as dead code.
    return ["chat-bison-001"]
18 changes: 18 additions & 0 deletions superagi/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,21 @@ def verify_access_key(self):
except Exception as exception:
logger.info("OpenAi Exception:", exception)
return False

def get_models(self):
    """
    Get the OpenAI models usable with the configured API key.

    Queries the OpenAI models endpoint and intersects the result with the
    model ids SuperAGI supports.

    Returns:
        list: Supported model ids available to this key; empty list on error.
    """
    try:
        response = openai.Model.list()
        available = [model["id"] for model in response["data"]]
        models_supported = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
        # Only expose models that both the key can access and SuperAGI supports.
        # (Removed leftover debug print of the raw model list.)
        return [model for model in available if model in models_supported]
    except Exception as exception:
        # Best-effort: log and return an empty list so callers degrade gracefully.
        # Use lazy %-formatting so the message actually includes the exception.
        logger.info("OpenAi Exception: %s", exception)
        return []

1 comment on commit b58067f

@clappo143
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is still no API endpoint for GPT-4-32k (yet). You can check the latest list of OpenAI's API endpoints:

curl https://api.openai.com/v1/models \
  -H "Authorization: Bearer $OPENAI_API_KEY"

On the other hand, I think it's worth noting that the gpt-3.5-turbo-16k-0613 model is currently available. While it's fine-tuned for OpenAI's function calling feature (which I do not believe SuperAGI has been optimised for), my understanding is that it is also better at following system prompts than the standard gpt-3.5-turbo-16k model (which sounds very appealing/useful in the context of SuperAGI).

I've added gpt-3.5-turbo-16k**-0613** to a fork and it works, though I am yet to test it directly against the standard 3.5-16k model to see if there are any noticeable performance differences (for better or worse). Just some food for thought – if anything interesting comes out of the testing will report back.

Also, cheers for addressing the issues with implementation of the palm/bison(text/chat...confusing) model. Hopefully does the trick – look forward to giving it a try!

Please sign in to comment.