Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

models marketplace changes #1219

Merged
merged 3 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions gui/pages/Content/Models/MarketModels.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Image from "next/image";
import {loadingTextEffect, modelIcon, returnToolkitIcon} from "@/utils/utils";
import {EventBus} from "@/utils/eventBus";
import {fetchMarketPlaceModel} from "@/pages/api/DashboardService";
import axios from "axios";

export default function MarketModels(){
const [showMarketplace, setShowMarketplace] = useState(false);
Expand All @@ -13,15 +14,27 @@ export default function MarketModels(){

useEffect(() => {
loadingTextEffect('Loading Models', setLoadingText, 500);
},[]);

useEffect(() => {
fetchMarketPlaceModel().then((response) => {
console.log(response.data)
setModelTemplates(response.data)
})
if (window.location.href.toLowerCase().includes('marketplace')) {
axios.get('https://app.superagi.com/api/models_controller/get/models_details')
.then((response) => {
setModelTemplates(response.data)
})
}
else {
fetchMarketPlaceModel().then((response) => {
setModelTemplates(response.data)
})
}
},[])

useEffect(() => {
if(modelTemplates.length > 0)
setIsLoading(true)
else
setIsLoading(false)
}, [modelTemplates])

function handleTemplateClick(item) {
const contentType = 'model_template';
EventBus.emit('openTemplateDetails', {item, contentType});
Expand All @@ -30,7 +43,7 @@ export default function MarketModels(){
return(
<div id="market_models" className={showMarketplace ? 'ml_8' : 'ml_3'}>
<div className="w_100 overflowY_auto mxh_78vh">
{!isLoading ? <div>
{isLoading ? <div>
{modelTemplates.length > 0 ? <div className="marketplaceGrid">{modelTemplates.map((item) => (
<div className="market_containers cursor_pointer" key={item.id} onClick={() => handleTemplateClick(item)}>
<div>{item.model_name && item.model_name.includes('/') ? item.model_name.split('/')[1] : item.model_name}</div>
Expand Down
17 changes: 15 additions & 2 deletions gui/pages/Content/Models/ModelTemplate.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ export default function ModelTemplate({env, template, getModels, sendModelData})
EventBus.emit('goToMarketplace', {});
}

function handleInstallClick() {
if (window.location.href.toLowerCase().includes('marketplace')) {
if (env === 'PROD') {
window.open(`https://app.superagi.com/`, '_self');
} else {
window.location.href = '/';
}
}
else {
setIsInstalled(true)
}
}

return (
<div id="model_template">
<div className="back_button mt_16 mb_16" onClick={() => isInstalled ? setIsInstalled(false) : handleBackClick()}>
Expand All @@ -20,7 +33,7 @@ export default function ModelTemplate({env, template, getModels, sendModelData})
<div className="col_3 display_column_container padding_16">
<span className="text_20 color_white">{template.model_name}</span>
<span className="text_12 color_gray mt_4">by {template.model_name.includes('/') ? template.model_name.split('/')[0] : template.provider}</span>
<button className="primary_button w_100 mt_16" disabled={template.is_installed} onClick={() => setIsInstalled(true)}>
<button className="primary_button w_100 mt_16" disabled={template.is_installed} onClick={() => handleInstallClick()}>
<Image width={16} height={16} src={template.is_installed ? '/images/tick.svg' : '/images/marketplace_download.svg'} alt="download-icon" />
<span className="ml_8">{template.is_installed ? 'Installed' : 'Install'}</span>
</button>
Expand All @@ -39,7 +52,7 @@ export default function ModelTemplate({env, template, getModels, sendModelData})
<span className="text_12 color_gray">Updated At</span>
<span className="text_12 color_white mt_8">{getFormattedDate(template.updated_at)}</span>
</div>
<div className="col_9 display_column_container padding_16 color_white" dangerouslySetInnerHTML={{ __html: template.model_features }} />
<div className="col_9 display_column_container padding_16 color_white text_12 lh_18" dangerouslySetInnerHTML={{ __html: template.model_features }} />
</div> ):(
<AddModelMarketPlace template={template} getModels={getModels} sendModelData={sendModelData}/>
)}
Expand Down
33 changes: 29 additions & 4 deletions superagi/controllers/models_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
from superagi.controllers.types.models_types import ModelsTypes
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
Expand Down Expand Up @@ -102,7 +103,7 @@


@router.get("/get/list", status_code=200)
def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)):
def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)):
"""
Get Marketplace Model list.

Expand All @@ -116,12 +117,12 @@
if page < 0:
page = 0
marketplace_models = Models.fetch_marketplace_list(page)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation.id)

Check warning on line 120 in superagi/controllers/models_controller.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/models_controller.py#L120

Added line #L120 was not covered by tests
return marketplace_models_with_install


@router.get("/marketplace/list/{page}", status_code=200)
def get_marketplace_knowledge_list(page: int = 0):
def get_marketplace_models_list(page: int = 0):
organisation_id = get_config("MARKETPLACE_ORGANISATION_ID")
if organisation_id is not None:
organisation_id = int(organisation_id)
Expand All @@ -131,4 +132,28 @@
if page < 0:
models = query.all()
models = query.offset(page * page_size).limit(page_size).all()
return models
return models


@router.get("/get/models_details", status_code=200)
def get_models_details(page: int = 0):
"""
Get Marketplace Model list.

Args:
page (int, optional): The page number for pagination. Defaults to None.

Returns:
dict: The response containing the marketplace list.

"""
organisation_id = get_config("MARKETPLACE_ORGANISATION_ID")

Check warning on line 150 in superagi/controllers/models_controller.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/models_controller.py#L150

Added line #L150 was not covered by tests
if organisation_id is not None:
organisation_id = int(organisation_id)

Check warning on line 152 in superagi/controllers/models_controller.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/models_controller.py#L152

Added line #L152 was not covered by tests

if page < 0:
page = 0
marketplace_models = Models.fetch_marketplace_list(page)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id,

Check warning on line 157 in superagi/controllers/models_controller.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/models_controller.py#L155-L157

Added lines #L155 - L157 were not covered by tests
ModelsTypes.MARKETPLACE.value)
return marketplace_models_with_install

Check warning on line 159 in superagi/controllers/models_controller.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/models_controller.py#L159

Added line #L159 was not covered by tests
14 changes: 14 additions & 0 deletions superagi/controllers/types/models_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from enum import Enum

class ModelsTypes(Enum):
MARKETPLACE = "Marketplace"
CUSTOM = "Custom"

@classmethod
def get_models_types(cls, model_type):
if model_type is None:
raise ValueError("Queue status type cannot be None.")
model_type = model_type.upper()

Check warning on line 11 in superagi/controllers/types/models_types.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/types/models_types.py#L10-L11

Added lines #L10 - L11 were not covered by tests
if model_type in cls.__members__:
return cls[model_type]
raise ValueError(f"{model_type} is not a valid storage name.")

Check warning on line 14 in superagi/controllers/types/models_types.py

View check run for this annotation

Codecov / codecov/patch

superagi/controllers/types/models_types.py#L13-L14

Added lines #L13 - L14 were not covered by tests
11 changes: 7 additions & 4 deletions superagi/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
from superagi.llms.openai import OpenAi
from superagi.controllers.types.models_types import ModelsTypes
from superagi.helper.encyption_helper import decrypt_data
import requests, logging

Expand Down Expand Up @@ -64,17 +64,20 @@
return []

@classmethod
def get_model_install_details(cls, session, marketplace_models, organisation):
def get_model_install_details(cls, session, marketplace_models, organisation_id, type=ModelsTypes.CUSTOM.value):
from superagi.models.models_config import ModelsConfig
installed_models = session.query(Models).filter(Models.org_id == organisation.id).all()
installed_models = session.query(Models).filter(Models.org_id == organisation_id).all()

Check warning on line 69 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L69

Added line #L69 was not covered by tests
model_counts_dict = dict(
session.query(Models.model_name, func.count(Models.org_id)).group_by(Models.model_name).all()
)
installed_models_dict = {model.model_name: True for model in installed_models}

for model in marketplace_models:
try:
model["is_installed"] = installed_models_dict.get(model["model_name"], False)
if type == ModelsTypes.MARKETPLACE.value:
model["is_installed"] = False

Check warning on line 78 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L78

Added line #L78 was not covered by tests
else:
model["is_installed"] = installed_models_dict.get(model["model_name"], False)

Check warning on line 80 in superagi/models/models.py

View check run for this annotation

Codecov / codecov/patch

superagi/models/models.py#L80

Added line #L80 was not covered by tests
model["installs"] = model_counts_dict.get(model["model_name"], 0)
model["provider"] = session.query(ModelsConfig).filter(
ModelsConfig.id == model["model_provider_id"]).first().provider
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/controllers/test_models_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_fetch_data_success(mock_get_db):
assert response.status_code == 200

@patch('superagi.controllers.models_controller.db')
def test_get_marketplace_knowledge_list_success(mock_get_db):
def test_get_marketplace_models_list_success(mock_get_db):
with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.helper.auth.db') as mock_auth_db, \
patch('superagi.controllers.models_controller.requests.get') as mock_get:
Expand All @@ -95,7 +95,7 @@ def test_get_marketplace_knowledge_list_success(mock_get_db):
assert response.status_code == 200

@patch('superagi.controllers.models_controller.db')
def test_get_marketplace_knowledge_list_success(mock_get_db):
def test_get_marketplace_models_list_success(mock_get_db):
with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.helper.auth.db') as mock_auth_db:
response = client.get("/models_controller/marketplace/list/0")
Expand Down