diff --git a/gui/pages/Content/Models/MarketModels.js b/gui/pages/Content/Models/MarketModels.js index 5b864febf..383b48be7 100644 --- a/gui/pages/Content/Models/MarketModels.js +++ b/gui/pages/Content/Models/MarketModels.js @@ -4,6 +4,7 @@ import Image from "next/image"; import {loadingTextEffect, modelIcon, returnToolkitIcon} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; import {fetchMarketPlaceModel} from "@/pages/api/DashboardService"; +import axios from "axios"; export default function MarketModels(){ const [showMarketplace, setShowMarketplace] = useState(false); @@ -13,15 +14,27 @@ export default function MarketModels(){ useEffect(() => { loadingTextEffect('Loading Models', setLoadingText, 500); - },[]); - useEffect(() => { - fetchMarketPlaceModel().then((response) => { - console.log(response.data) - setModelTemplates(response.data) - }) + if (window.location.href.toLowerCase().includes('marketplace')) { + axios.get('https://app.superagi.com/api/models_controller/get/models_details') + .then((response) => { + setModelTemplates(response.data) + }) + } + else { + fetchMarketPlaceModel().then((response) => { + setModelTemplates(response.data) + }) + } },[]) + useEffect(() => { + if(modelTemplates.length > 0) + setIsLoading(true) + else + setIsLoading(false) + }, [modelTemplates]) + function handleTemplateClick(item) { const contentType = 'model_template'; EventBus.emit('openTemplateDetails', {item, contentType}); @@ -30,7 +43,7 @@ export default function MarketModels(){ return(
- {!isLoading ?
+ {isLoading ?
{modelTemplates.length > 0 ?
{modelTemplates.map((item) => (
handleTemplateClick(item)}>
{item.model_name && item.model_name.includes('/') ? item.model_name.split('/')[1] : item.model_name}
diff --git a/gui/pages/Content/Models/ModelTemplate.js b/gui/pages/Content/Models/ModelTemplate.js index 11203739e..deb503767 100644 --- a/gui/pages/Content/Models/ModelTemplate.js +++ b/gui/pages/Content/Models/ModelTemplate.js @@ -10,6 +10,19 @@ export default function ModelTemplate({env, template, getModels, sendModelData}) EventBus.emit('goToMarketplace', {}); } + function handleInstallClick() { + if (window.location.href.toLowerCase().includes('marketplace')) { + if (env === 'PROD') { + window.open(`https://app.superagi.com/`, '_self'); + } else { + window.location.href = '/'; + } + } + else { + setIsInstalled(true) + } + } + return (
isInstalled ? setIsInstalled(false) : handleBackClick()}> @@ -20,7 +33,7 @@ export default function ModelTemplate({env, template, getModels, sendModelData})
{template.model_name} by {template.model_name.includes('/') ? template.model_name.split('/')[0] : template.provider} - @@ -39,7 +52,7 @@ export default function ModelTemplate({env, template, getModels, sendModelData}) Updated At {getFormattedDate(template.updated_at)}
-
+
):( )} diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index abb771060..f3704d05c 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -5,6 +5,7 @@ from superagi.models.models import Models from superagi.models.models_config import ModelsConfig from superagi.config.config import get_config +from superagi.controllers.types.models_types import ModelsTypes from fastapi_sqlalchemy import db import logging from pydantic import BaseModel @@ -102,7 +103,7 @@ async def fetch_data(request: ModelName, organisation=Depends(get_user_organisat @router.get("/get/list", status_code=200) -def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)): +def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)): """ Get Marketplace Model list. @@ -116,12 +117,12 @@ def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation if page < 0: page = 0 marketplace_models = Models.fetch_marketplace_list(page) - marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation) + marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation.id) return marketplace_models_with_install @router.get("/marketplace/list/{page}", status_code=200) -def get_marketplace_knowledge_list(page: int = 0): +def get_marketplace_models_list(page: int = 0): organisation_id = get_config("MARKETPLACE_ORGANISATION_ID") if organisation_id is not None: organisation_id = int(organisation_id) @@ -131,4 +132,28 @@ def get_marketplace_knowledge_list(page: int = 0): if page < 0: models = query.all() models = query.offset(page * page_size).limit(page_size).all() - return models \ No newline at end of file + return models + + +@router.get("/get/models_details", status_code=200) +def get_models_details(page: int = 0): + """ + Get Marketplace Model list. + + Args: + page (int, optional): The page number for pagination. Defaults to None. + + Returns: + dict: The response containing the marketplace list. + + """ + organisation_id = get_config("MARKETPLACE_ORGANISATION_ID") + if organisation_id is not None: + organisation_id = int(organisation_id) + + if page < 0: + page = 0 + marketplace_models = Models.fetch_marketplace_list(page) + marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id, + ModelsTypes.MARKETPLACE.value) + return marketplace_models_with_install \ No newline at end of file diff --git a/superagi/controllers/types/models_types.py b/superagi/controllers/types/models_types.py new file mode 100644 index 000000000..4f6b9cc15 --- /dev/null +++ b/superagi/controllers/types/models_types.py @@ -0,0 +1,14 @@ +from enum import Enum + +class ModelsTypes(Enum): + MARKETPLACE = "Marketplace" + CUSTOM = "Custom" + + @classmethod + def get_models_types(cls, model_type): + if model_type is None: + raise ValueError("Queue status type cannot be None.") + model_type = model_type.upper() + if model_type in cls.__members__: + return cls[model_type] + raise ValueError(f"{model_type} is not a valid storage name.") \ No newline at end of file diff --git a/superagi/models/models.py b/superagi/models/models.py index 815697298..ccd5cdcf5 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -2,7 +2,7 @@ from sqlalchemy.sql import func from typing import List, Dict, Union from superagi.models.base_model import DBBaseModel -from superagi.llms.openai import OpenAi +from superagi.controllers.types.models_types import ModelsTypes from superagi.helper.encyption_helper import decrypt_data import requests, logging @@ -64,9 +64,9 @@ def fetch_marketplace_list(cls, page): return [] @classmethod - def get_model_install_details(cls, session, marketplace_models, organisation): + def get_model_install_details(cls, session, marketplace_models, organisation_id, type=ModelsTypes.CUSTOM.value): from superagi.models.models_config import ModelsConfig - installed_models = session.query(Models).filter(Models.org_id == organisation.id).all() + installed_models = session.query(Models).filter(Models.org_id == organisation_id).all() model_counts_dict = dict( session.query(Models.model_name, func.count(Models.org_id)).group_by(Models.model_name).all() ) @@ -74,7 +74,10 @@ def get_model_install_details(cls, session, marketplace_models, organisation): for model in marketplace_models: try: - model["is_installed"] = installed_models_dict.get(model["model_name"], False) + if type == ModelsTypes.MARKETPLACE.value: + model["is_installed"] = False + else: + model["is_installed"] = installed_models_dict.get(model["model_name"], False) model["installs"] = model_counts_dict.get(model["model_name"], 0) model["provider"] = session.query(ModelsConfig).filter( ModelsConfig.id == model["model_provider_id"]).first().provider diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index 4e8adc64d..489cff636 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -82,7 +82,7 @@ def test_fetch_data_success(mock_get_db): assert response.status_code == 200 @patch('superagi.controllers.models_controller.db') -def test_get_marketplace_knowledge_list_success(mock_get_db): +def test_get_marketplace_models_list_success(mock_get_db): with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \ patch('superagi.helper.auth.db') as mock_auth_db, \ patch('superagi.controllers.models_controller.requests.get') as mock_get: @@ -95,7 +95,7 @@ def test_get_marketplace_knowledge_list_success(mock_get_db): assert response.status_code == 200 @patch('superagi.controllers.models_controller.db') -def test_get_marketplace_knowledge_list_success(mock_get_db): +def test_get_marketplace_models_list_success(mock_get_db): with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \ patch('superagi.helper.auth.db') as mock_auth_db: response = client.get("/models_controller/marketplace/list/0")