diff --git a/gui/pages/Content/Models/MarketModels.js b/gui/pages/Content/Models/MarketModels.js
index 5b864febf..383b48be7 100644
--- a/gui/pages/Content/Models/MarketModels.js
+++ b/gui/pages/Content/Models/MarketModels.js
@@ -4,6 +4,7 @@ import Image from "next/image";
import {loadingTextEffect, modelIcon, returnToolkitIcon} from "@/utils/utils";
import {EventBus} from "@/utils/eventBus";
import {fetchMarketPlaceModel} from "@/pages/api/DashboardService";
+import axios from "axios";
export default function MarketModels(){
const [showMarketplace, setShowMarketplace] = useState(false);
@@ -13,15 +14,27 @@ export default function MarketModels(){
useEffect(() => {
loadingTextEffect('Loading Models', setLoadingText, 500);
- },[]);
- useEffect(() => {
- fetchMarketPlaceModel().then((response) => {
- console.log(response.data)
- setModelTemplates(response.data)
- })
+ if (window.location.href.toLowerCase().includes('marketplace')) {
+ axios.get('https://app.superagi.com/api/models_controller/get/models_details')
+ .then((response) => {
+ setModelTemplates(response.data)
+ })
+ }
+ else {
+ fetchMarketPlaceModel().then((response) => {
+ setModelTemplates(response.data)
+ })
+ }
},[])
+ useEffect(() => {
+ if(modelTemplates.length > 0)
+ setIsLoading(true)
+ else
+ setIsLoading(false)
+ }, [modelTemplates])
+
function handleTemplateClick(item) {
const contentType = 'model_template';
EventBus.emit('openTemplateDetails', {item, contentType});
@@ -30,7 +43,7 @@ export default function MarketModels(){
return(
- {!isLoading ?
+ {isLoading ?
{modelTemplates.length > 0 ?
{modelTemplates.map((item) => (
handleTemplateClick(item)}>
{item.model_name && item.model_name.includes('/') ? item.model_name.split('/')[1] : item.model_name}
diff --git a/gui/pages/Content/Models/ModelTemplate.js b/gui/pages/Content/Models/ModelTemplate.js
index 11203739e..deb503767 100644
--- a/gui/pages/Content/Models/ModelTemplate.js
+++ b/gui/pages/Content/Models/ModelTemplate.js
@@ -10,6 +10,19 @@ export default function ModelTemplate({env, template, getModels, sendModelData})
EventBus.emit('goToMarketplace', {});
}
+ function handleInstallClick() {
+ if (window.location.href.toLowerCase().includes('marketplace')) {
+ if (env === 'PROD') {
+ window.open(`https://app.superagi.com/`, '_self');
+ } else {
+ window.location.href = '/';
+ }
+ }
+ else {
+ setIsInstalled(true)
+ }
+ }
+
return (
isInstalled ? setIsInstalled(false) : handleBackClick()}>
@@ -20,7 +33,7 @@ export default function ModelTemplate({env, template, getModels, sendModelData})
{template.model_name}
by {template.model_name.includes('/') ? template.model_name.split('/')[0] : template.provider}
-
-
+
):(
)}
diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py
index abb771060..f3704d05c 100644
--- a/superagi/controllers/models_controller.py
+++ b/superagi/controllers/models_controller.py
@@ -5,6 +5,7 @@
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
+from superagi.controllers.types.models_types import ModelsTypes
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
@@ -102,7 +103,7 @@ async def fetch_data(request: ModelName, organisation=Depends(get_user_organisat
@router.get("/get/list", status_code=200)
-def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)):
+def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)):
"""
Get Marketplace Model list.
@@ -116,12 +117,12 @@ def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation
if page < 0:
page = 0
marketplace_models = Models.fetch_marketplace_list(page)
- marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation)
+ marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation.id)
return marketplace_models_with_install
@router.get("/marketplace/list/{page}", status_code=200)
-def get_marketplace_knowledge_list(page: int = 0):
+def get_marketplace_models_list(page: int = 0):
organisation_id = get_config("MARKETPLACE_ORGANISATION_ID")
if organisation_id is not None:
organisation_id = int(organisation_id)
@@ -131,4 +132,28 @@ def get_marketplace_knowledge_list(page: int = 0):
if page < 0:
models = query.all()
models = query.offset(page * page_size).limit(page_size).all()
- return models
\ No newline at end of file
+ return models
+
+
+@router.get("/get/models_details", status_code=200)
+def get_models_details(page: int = 0):
+ """
+ Get Marketplace Model list.
+
+ Args:
+ page (int, optional): The page number for pagination. Defaults to None.
+
+ Returns:
+ dict: The response containing the marketplace list.
+
+ """
+ organisation_id = get_config("MARKETPLACE_ORGANISATION_ID")
+ if organisation_id is not None:
+ organisation_id = int(organisation_id)
+
+ if page < 0:
+ page = 0
+ marketplace_models = Models.fetch_marketplace_list(page)
+ marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id,
+ ModelsTypes.MARKETPLACE.value)
+ return marketplace_models_with_install
\ No newline at end of file
diff --git a/superagi/controllers/types/models_types.py b/superagi/controllers/types/models_types.py
new file mode 100644
index 000000000..4f6b9cc15
--- /dev/null
+++ b/superagi/controllers/types/models_types.py
@@ -0,0 +1,14 @@
+from enum import Enum
+
+class ModelsTypes(Enum):
+ MARKETPLACE = "Marketplace"
+ CUSTOM = "Custom"
+
+ @classmethod
+ def get_models_types(cls, model_type):
+ if model_type is None:
+ raise ValueError("Queue status type cannot be None.")
+ model_type = model_type.upper()
+ if model_type in cls.__members__:
+ return cls[model_type]
+ raise ValueError(f"{model_type} is not a valid storage name.")
\ No newline at end of file
diff --git a/superagi/models/models.py b/superagi/models/models.py
index 815697298..ccd5cdcf5 100644
--- a/superagi/models/models.py
+++ b/superagi/models/models.py
@@ -2,7 +2,7 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
-from superagi.llms.openai import OpenAi
+from superagi.controllers.types.models_types import ModelsTypes
from superagi.helper.encyption_helper import decrypt_data
import requests, logging
@@ -64,9 +64,9 @@ def fetch_marketplace_list(cls, page):
return []
@classmethod
- def get_model_install_details(cls, session, marketplace_models, organisation):
+ def get_model_install_details(cls, session, marketplace_models, organisation_id, type=ModelsTypes.CUSTOM.value):
from superagi.models.models_config import ModelsConfig
- installed_models = session.query(Models).filter(Models.org_id == organisation.id).all()
+ installed_models = session.query(Models).filter(Models.org_id == organisation_id).all()
model_counts_dict = dict(
session.query(Models.model_name, func.count(Models.org_id)).group_by(Models.model_name).all()
)
@@ -74,7 +74,10 @@ def get_model_install_details(cls, session, marketplace_models, organisation):
for model in marketplace_models:
try:
- model["is_installed"] = installed_models_dict.get(model["model_name"], False)
+ if type == ModelsTypes.MARKETPLACE.value:
+ model["is_installed"] = False
+ else:
+ model["is_installed"] = installed_models_dict.get(model["model_name"], False)
model["installs"] = model_counts_dict.get(model["model_name"], 0)
model["provider"] = session.query(ModelsConfig).filter(
ModelsConfig.id == model["model_provider_id"]).first().provider
diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py
index 4e8adc64d..489cff636 100644
--- a/tests/unit_tests/controllers/test_models_controller.py
+++ b/tests/unit_tests/controllers/test_models_controller.py
@@ -82,7 +82,7 @@ def test_fetch_data_success(mock_get_db):
assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
-def test_get_marketplace_knowledge_list_success(mock_get_db):
+def test_get_marketplace_models_list_success(mock_get_db):
with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.helper.auth.db') as mock_auth_db, \
patch('superagi.controllers.models_controller.requests.get') as mock_get:
@@ -95,7 +95,7 @@ def test_get_marketplace_knowledge_list_success(mock_get_db):
assert response.status_code == 200
@patch('superagi.controllers.models_controller.db')
-def test_get_marketplace_knowledge_list_success(mock_get_db):
+def test_get_marketplace_models_list_success(mock_get_db):
with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.helper.auth.db') as mock_auth_db:
response = client.get("/models_controller/marketplace/list/0")