From b441007dc40f0a2229d4abf3e324c7f71a287cab Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 9 Aug 2023 11:22:27 +0530 Subject: [PATCH 01/34] Added code to connect weaviate vectoro db --- superagi/controllers/vector_db_indices.py | 10 ++++++++-- superagi/controllers/vector_dbs.py | 20 +++++++++++++++++++ .../vector_embedding_factory.py | 5 ++++- superagi/vector_embeddings/weaviate.py | 18 +++++++++++++++++ superagi/vector_store/vector_factory.py | 9 ++++++++- superagi/vector_store/weaviate.py | 6 ++++-- 6 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 superagi/vector_embeddings/weaviate.py diff --git a/superagi/controllers/vector_db_indices.py b/superagi/controllers/vector_db_indices.py index 667bc953e..59231348c 100644 --- a/superagi/controllers/vector_db_indices.py +++ b/superagi/controllers/vector_db_indices.py @@ -16,6 +16,7 @@ def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(ge knowledge_with_config = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(knowledge['id']) pinecone = [] qdrant = [] + weaviate = [] for vector_db in vector_dbs: indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id) for index in indices: @@ -26,13 +27,16 @@ def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(ge pinecone.append(data) if vector_db.db_type == "Qdrant": qdrant.append(data) - return {"pinecone": pinecone, "qdrant": qdrant} + if vector_db.db_type == "Weaviate": + weaviate.append(data) + return {"pinecone": pinecone, "qdrant": qdrant, "weaviate": weaviate} @router.get("/user/valid_indices") def get_user_valid_indices(organisation = Depends(get_user_organisation)): vector_dbs = Vectordbs.get_vector_db_from_organisation(db.session, organisation) pinecone = [] qdrant = [] + weaviate = [] for vector_db in vector_dbs: indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id) for index in indices: @@ -42,4 +46,6 @@ def get_user_valid_indices(organisation = Depends(get_user_organisation)): pinecone.append(data) if vector_db.db_type == "Qdrant": qdrant.append(data) - return {"pinecone": pinecone, "qdrant": qdrant} \ No newline at end of file + if vector_db.db_type == "Weaviate": + weaviate.append(data) + return {"pinecone": pinecone, "qdrant": qdrant, "weaviate": weaviate} \ No newline at end of file diff --git a/superagi/controllers/vector_dbs.py b/superagi/controllers/vector_dbs.py index 66e49ff63..4b2286e6c 100644 --- a/superagi/controllers/vector_dbs.py +++ b/superagi/controllers/vector_dbs.py @@ -101,6 +101,26 @@ def connect_qdrant_vector_db(data: dict, organisation = Depends(get_user_organis return {"id": qdrant_db.id, "name": qdrant_db.name} +@router.post("/connect/weaviate") +def connect_weaviate_vector_db(data: dict, organisation = Depends(get_user_organisation)): + db_creds = { + "api_key": data["api_key"], + "url": data["url"] + } + for collection in data["collections"]: + try: + vector_db_storage = VectorFactory.build_vector_storage("weaviate", collection, **db_creds) + db_connect_for_index = vector_db_storage.get_index_stats() + index_state = "Custom" if db_connect_for_index["vector_count"] > 0 else "None" + except: + raise HTTPException(status_code=400, detail="Unable to connect Weaviate") + weaviate_db = Vectordbs.add_vector_db(db.session, data["name"], "Weaviate", organisation) + VectordbConfigs.add_vector_db_config(db.session, weaviate_db.id, db_creds) + for collection in data["collections"]: + VectordbIndices.add_vector_index(db.session, collection, weaviate_db.id, data["dimensions"], index_state) + + return {"id": weaviate_db.id, "name": weaviate_db.name} + @router.put("/update/vector_db/{vector_db_id}") def update_vector_db(new_indices: list, vector_db_id: int): vector_db = Vectordbs.get_vector_db_from_id(db.session, vector_db_id) diff --git a/superagi/vector_embeddings/vector_embedding_factory.py b/superagi/vector_embeddings/vector_embedding_factory.py index 827edb79c..0c4e2984c 100644 --- a/superagi/vector_embeddings/vector_embedding_factory.py +++ b/superagi/vector_embeddings/vector_embedding_factory.py @@ -40,4 +40,7 @@ def build_vector_storage(cls, vector_store: VectorStoreType, chunk_json: Optiona return Pinecone(uuid, embeds, metadata) if vector_store == VectorStoreType.QDRANT: - return Qdrant(uuid, embeds, metadata) \ No newline at end of file + return Qdrant(uuid, embeds, metadata) + + if vector_store == VectorStoreType.WEAVIATE: + return Weaviate(uuid, embeds, metadata) \ No newline at end of file diff --git a/superagi/vector_embeddings/weaviate.py b/superagi/vector_embeddings/weaviate.py new file mode 100644 index 000000000..9d77bd15b --- /dev/null +++ b/superagi/vector_embeddings/weaviate.py @@ -0,0 +1,18 @@ +from typing import Any +from superagi.vector_embeddings.base import VectorEmbeddings + +class Weaviate(VectorEmbeddings): + + def __init__(self, uuid, embeds, metadata): + self.uuid = uuid + self.embeds = embeds + self.metadata = metadata + + def get_vector_embeddings_from_chunks(self): + """ Returns embeddings for vector dbs from final chunks""" + result = {} + result['ids'] = self.uuid + result['data_object'] = self.metadata + result['vectors'] = self.embeds + + return result \ No newline at end of file diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py index 4d9dbd157..79206939a 100644 --- a/superagi/vector_store/vector_factory.py +++ b/superagi/vector_store/vector_factory.py @@ -93,4 +93,11 @@ def build_vector_storage(cls, vector_store: VectorStoreType, index_name, embeddi client = qdrant.create_qdrant_client(creds["api_key"], creds["url"], creds["port"]) return qdrant.Qdrant(client, embedding_model, index_name) except: - raise ValueError("Qdrant API key not found") \ No newline at end of file + raise ValueError("Qdrant API key not found") + + if vector_store == VectorStoreType.WEAVIATE: + try: + client = weaviate.create_weaviate_client(creds["url"], creds["api_key"]) + return weaviate.Weaviate(client, embedding_model, index_name) + except: + raise ValueError("Weaviate API key not found") \ No newline at end of file diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index e9fdc236b..b439a4a2e 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -45,7 +45,7 @@ def create_weaviate_client( class Weaviate(VectorStore): def __init__( - self, client: weaviate.Client, embedding_model: Any, index: str, text_field: str + self, client: weaviate.Client, embedding_model: Any, index: str, text_field: str = "text" ): self.index = index self.embedding_model = embedding_model @@ -106,7 +106,9 @@ def _get_metadata_fields(self) -> List[str]: return property_names def get_index_stats(self) -> dict: - pass + result = self.client.query.get(self.index).with_meta_count().do() + vector_count = result['data']['Aggregate'][self.index][0]['meta']['count'] + return {'vector_count': vector_count} def add_embeddings_to_vector_db(self, embeddings: dict) -> None: pass From f78396c0ce6594d0a669245db306904249a20334 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 9 Aug 2023 11:54:02 +0530 Subject: [PATCH 02/34] add embeddings to vector_db weaviate --- superagi/vector_embeddings/vector_embedding_factory.py | 1 + superagi/vector_store/weaviate.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/superagi/vector_embeddings/vector_embedding_factory.py b/superagi/vector_embeddings/vector_embedding_factory.py index 0c4e2984c..908d499cd 100644 --- a/superagi/vector_embeddings/vector_embedding_factory.py +++ b/superagi/vector_embeddings/vector_embedding_factory.py @@ -4,6 +4,7 @@ from pinecone import UnauthorizedException from superagi.vector_embeddings.pinecone import Pinecone from superagi.vector_embeddings.qdrant import Qdrant +from superagi.vector_embeddings.weaviate import Weaviate from superagi.types.vector_store_types import VectorStoreType class VectorEmbeddingFactory: diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index b439a4a2e..ce822dd32 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -111,7 +111,13 @@ def get_index_stats(self) -> dict: return {'vector_count': vector_count} def add_embeddings_to_vector_db(self, embeddings: dict) -> None: - pass - + try: + with self.client.batch as batch: + for i in range(len(embeddings['ids'])): + data_object = {key: value for key, value in embeddings['data_object'][i].items()} + batch.add_data_object(data_object, class_name=self.index, uuid=embeddings['ids'][i], vector=embeddings['vectors'][i]) + except Exception as err: + raise err + def delete_embeddings_from_vector_db(self, ids: List[str]) -> None: pass \ No newline at end of file From d06c2a79085316175896b2b9d3a30703a00cd71c Mon Sep 17 00:00:00 2001 From: namansleeps Date: Wed, 9 Aug 2023 13:17:17 +0530 Subject: [PATCH 03/34] weaviate frontend complete --- .../Content/Marketplace/KnowledgeTemplate.js | 24 +++++++ gui/pages/Dashboard/Settings/AddDatabase.js | 71 ++++++++++++++++++- gui/pages/_app.css | 1 + gui/pages/api/DashboardService.js | 4 ++ gui/public/images/weaviate.svg | 9 +++ gui/utils/utils.js | 3 +- 6 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 gui/public/images/weaviate.svg diff --git a/gui/pages/Content/Marketplace/KnowledgeTemplate.js b/gui/pages/Content/Marketplace/KnowledgeTemplate.js index d83551e60..2ccbebe4a 100644 --- a/gui/pages/Content/Marketplace/KnowledgeTemplate.js +++ b/gui/pages/Content/Marketplace/KnowledgeTemplate.js @@ -25,6 +25,7 @@ export default function KnowledgeTemplate({template, env}) { const [indexDropdown, setIndexDropdown] = useState(false); const [pinconeIndices, setPineconeIndices] = useState([]); const [qdrantIndices, setQdrantIndices] = useState([]); + const [weaviateIndices, setWeaviateIndices] = useState([]); useEffect(() => { getValidMarketplaceIndices(template.name) @@ -33,6 +34,7 @@ export default function KnowledgeTemplate({template, env}) { if (data) { setPineconeIndices(data.pinecone || []); setQdrantIndices(data.qdrant || []); + setWeaviateIndices(data.weaviate || []) } }) .catch((error) => { @@ -239,6 +241,28 @@ export default function KnowledgeTemplate({template, env}) { } ))} } + {weaviateIndices && weaviateIndices.length > 0 && +
+
Weaviate
+ {weaviateIndices.map((index) => (
handleInstallClick(index.id)} style={{ + padding: '12px 14px', + maxWidth: '100%', + display: 'flex', + justifyContent: 'space-between' + }}> +
{index.name}
+ {!checkIndexValidity(index.is_valid_state, index.is_valid_dimension)[0] && +
+ info-icon +
} +
))} +
} } } diff --git a/gui/pages/Dashboard/Settings/AddDatabase.js b/gui/pages/Dashboard/Settings/AddDatabase.js index 0383f972a..7e8a19c9e 100644 --- a/gui/pages/Dashboard/Settings/AddDatabase.js +++ b/gui/pages/Dashboard/Settings/AddDatabase.js @@ -13,7 +13,7 @@ import knowledgeStyles from "@/pages/Content/Knowledge/Knowledge.module.css"; import styles from "@/pages/Content/Marketplace/Market.module.css"; import Image from "next/image"; import styles1 from "@/pages/Content/Agents/Agents.module.css"; -import {connectPinecone, connectQdrant, fetchVectorDBList} from "@/pages/api/DashboardService"; +import {connectPinecone, connectQdrant, connectWeaviate, fetchVectorDBList} from "@/pages/api/DashboardService"; export default function AddDatabase({internalId, sendDatabaseDetailsData}) { const [activeView, setActiveView] = useState('select_database'); @@ -27,6 +27,10 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) { const [qdrantApiKey, setQdrantApiKey] = useState(''); const [qdrantURL, setQdrantURL] = useState(''); + + const [weaviateApiKey, setWeaviateApiKey] = useState(''); + const [weaviateURL, setWeaviateURL] = useState(''); + const [qdrantPort, setQdrantPort] = useState(8001); const [connectText, setConnectText] = useState('Connect'); @@ -70,6 +74,17 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) { if (qdrant_port) { setQdrantPort(Number(qdrant_port)); } + + const weaviate_api = localStorage.getItem('weaviate_api_' + String(internalId)); + if (weaviate_api) { + setWeaviateApiKey(weaviate_api); + } + + const weaviate_url = localStorage.getItem('weaviate_url_' + String(internalId)); + if (weaviate_url) { + setWeaviateURL(weaviate_url); + } + }, [internalId]); useEffect(() => { @@ -109,6 +124,14 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) { setLocalStorageValue('qdrant_port_' + String(internalId), event.target.value, setQdrantPort); } + const handleWeaviateAPIKeyChange = (event) => { + setLocalStorageValue('weaviate_api_' + String(internalId), event.target.value, setWeaviateApiKey); + } + + const handleWeaviateURLChange = (event) => { + setLocalStorageValue('weaviate_url_' + String(internalId), event.target.value, setWeaviateURL); + } + const addCollection = () => { setLocalStorageArray("db_collections_" + String(internalId), [...collections, 'collection name'], setCollections); }; @@ -140,6 +163,11 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) { return; } + if(collections.length === 1 && collections[0].length < 1){ + toast.error("Atleast add 1 Collection/Index", {autoClose: 1800}); + return; + } + if (selectedDB === 'Pinecone') { if (pineconeApiKey.replace(/\s/g, '') === '') { toast.error("Pinecone API key is empty", {autoClose: 1800}); @@ -207,6 +235,37 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) { setConnectText("Connect"); }); } + + if (selectedDB === 'Weaviate') { + if (weaviateApiKey.replace(/\s/g, '') === '') { + toast.error("Weaviate API key is empty", {autoClose: 1800}); + return; + } + + if (weaviateURL.replace(/\s/g, '') === '') { + toast.error("Weaviate URL is empty", {autoClose: 1800}); + return; + } + + setConnectText("Connecting..."); + + const weaviateData = { + "name": databaseName, + "collections": collections, + "api_key": weaviateApiKey, + "url": weaviateURL, + } + + connectWeaviate(weaviateData) + .then((response) => { + connectResponse(response.data); + }) + .catch((error) => { + toast.error("Unable to connect database", {autoClose: 1800}); + console.error('Error fetching vector databases:', error); + setConnectText("Connect"); + }); + } } const proceedAddDatabase = () => { @@ -321,6 +380,16 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) { } + {selectedDB === 'Weaviate' &&
+
+ + +
+
+ + +
+
}
} ))} } + {weaviateIndices && weaviateIndices.length > 0 && +
+
Weaviate
+ {weaviateIndices.map((index) => (
handleIndexSelect(index)}> +
{index.name}
+ {!checkIndexValidity(index.is_valid_state)[0] && +
+ info-icon +
} +
))} +
} } From f043d5562422070ca06f076797e8f89214a74394 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 9 Aug 2023 13:39:17 +0530 Subject: [PATCH 05/34] code for uninstall --- superagi/vector_store/weaviate.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index ce822dd32..f496d4892 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -120,4 +120,11 @@ def add_embeddings_to_vector_db(self, embeddings: dict) -> None: raise err def delete_embeddings_from_vector_db(self, ids: List[str]) -> None: - pass \ No newline at end of file + try: + for id in ids: + self.client.data_object.delete( + uuid = id, + class_name = self.index + ) + except Exception as err: + raise err \ No newline at end of file From b581735a4620d92272de8562b0996149f6da40f0 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 9 Aug 2023 18:25:03 +0530 Subject: [PATCH 06/34] Dimensions change for weaviate --- superagi/controllers/vector_dbs.py | 8 ++++---- superagi/models/vector_db_indices.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/superagi/controllers/vector_dbs.py b/superagi/controllers/vector_dbs.py index 4b2286e6c..38c337a4b 100644 --- a/superagi/controllers/vector_dbs.py +++ b/superagi/controllers/vector_dbs.py @@ -77,7 +77,7 @@ def connect_pinecone_vector_db(data: dict, organisation = Depends(get_user_organ pinecone_db = Vectordbs.add_vector_db(db.session, data["name"], "Pinecone", organisation) VectordbConfigs.add_vector_db_config(db.session, pinecone_db.id, db_creds) for collection in data["collections"]: - VectordbIndices.add_vector_index(db.session, collection, pinecone_db.id, db_connect_for_index["dimensions"], index_state) + VectordbIndices.add_vector_index(db.session, collection, pinecone_db.id, index_state, db_connect_for_index["dimensions"]) return {"id": pinecone_db.id, "name": pinecone_db.name} @router.post("/connect/qdrant") @@ -97,7 +97,7 @@ def connect_qdrant_vector_db(data: dict, organisation = Depends(get_user_organis qdrant_db = Vectordbs.add_vector_db(db.session, data["name"], "Qdrant", organisation) VectordbConfigs.add_vector_db_config(db.session, qdrant_db.id, db_creds) for collection in data["collections"]: - VectordbIndices.add_vector_index(db.session, collection, qdrant_db.id, db_connect_for_index["dimensions"], index_state) + VectordbIndices.add_vector_index(db.session, collection, qdrant_db.id, index_state, db_connect_for_index["dimensions"]) return {"id": qdrant_db.id, "name": qdrant_db.name} @@ -117,7 +117,7 @@ def connect_weaviate_vector_db(data: dict, organisation = Depends(get_user_organ weaviate_db = Vectordbs.add_vector_db(db.session, data["name"], "Weaviate", organisation) VectordbConfigs.add_vector_db_config(db.session, weaviate_db.id, db_creds) for collection in data["collections"]: - VectordbIndices.add_vector_index(db.session, collection, weaviate_db.id, data["dimensions"], index_state) + VectordbIndices.add_vector_index(db.session, collection, weaviate_db.id, index_state) return {"id": weaviate_db.id, "name": weaviate_db.name} @@ -141,7 +141,7 @@ def update_vector_db(new_indices: list, vector_db_id: int): index_state = "Custom" if vector_db_index_stats["vector_count"] > 0 else "None" except: raise HTTPException(status_code=400, detail="Unable to update vector db") - VectordbIndices.add_vector_index(db.session, index, vector_db_id, vector_db_index_stats["dimensions"], index_state) + VectordbIndices.add_vector_index(db.session, index, vector_db_id, index_state, vector_db_index_stats["dimensions"]) diff --git a/superagi/models/vector_db_indices.py b/superagi/models/vector_db_indices.py index 5f522973f..eaba8b669 100644 --- a/superagi/models/vector_db_indices.py +++ b/superagi/models/vector_db_indices.py @@ -47,7 +47,7 @@ def delete_vector_db_index(cls, session, vector_index_id): session.commit() @classmethod - def add_vector_index(cls, session, index_name, vector_db_id, dimensions, state): + def add_vector_index(cls, session, index_name, vector_db_id, state, dimensions = None): vector_index = VectordbIndices(name=index_name, vector_db_id=vector_db_id, dimensions=dimensions, state=state) session.add(vector_index) session.commit() From b04404573050851839fc97252bc8506352ab8419 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 9 Aug 2023 19:39:07 +0530 Subject: [PATCH 07/34] Fix for getting index stats --- superagi/vector_store/weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index f496d4892..d0ba8e746 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -106,7 +106,7 @@ def _get_metadata_fields(self) -> List[str]: return property_names def get_index_stats(self) -> dict: - result = self.client.query.get(self.index).with_meta_count().do() + result = self.client.query.aggregate(self.index).with_meta_count().do() vector_count = result['data']['Aggregate'][self.index][0]['meta']['count'] return {'vector_count': vector_count} From 9e5e87cb403170898231bc367e91d2e6921a606d Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 11:28:53 +0530 Subject: [PATCH 08/34] Is valid dimension changed --- superagi/controllers/vector_db_indices.py | 1 + 1 file changed, 1 insertion(+) diff --git a/superagi/controllers/vector_db_indices.py b/superagi/controllers/vector_db_indices.py index 59231348c..710557712 100644 --- a/superagi/controllers/vector_db_indices.py +++ b/superagi/controllers/vector_db_indices.py @@ -28,6 +28,7 @@ def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(ge if vector_db.db_type == "Qdrant": qdrant.append(data) if vector_db.db_type == "Weaviate": + data["is_valid_dimension"] = True weaviate.append(data) return {"pinecone": pinecone, "qdrant": qdrant, "weaviate": weaviate} From 2e388447aaa1032ef7529620e5a2f161d5c69b6b Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 13:23:46 +0530 Subject: [PATCH 09/34] change in creation of client --- superagi/vector_store/weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index d0ba8e746..33d3a972f 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -10,7 +10,7 @@ def create_weaviate_client( - use_embedded: bool = True, + use_embedded: bool = False, url: Optional[str] = None, api_key: Optional[str] = None, ) -> weaviate.Client: From 897a8e8bec94c7b9be808e9546da0f68a1e603cc Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 13:59:30 +0530 Subject: [PATCH 10/34] Made changes --- superagi/vector_store/weaviate.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 33d3a972f..67cfcbe92 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -10,7 +10,6 @@ def create_weaviate_client( - use_embedded: bool = False, url: Optional[str] = None, api_key: Optional[str] = None, ) -> weaviate.Client: @@ -28,9 +27,7 @@ def create_weaviate_client( Raises: ValueError: If invalid argument combination are passed. """ - if use_embedded: - client = weaviate.Client(embedded_options=weaviate.embedded.EmbeddedOptions()) - elif url: + if url: if api_key: auth_config = weaviate.AuthApiKey(api_key=api_key) else: From 077582397f57f3bbf19708ad6592a3413670e999 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 16:08:40 +0530 Subject: [PATCH 11/34] Vector store add tests fixed for weaviate --- superagi/vector_store/weaviate.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 67cfcbe92..66f1b49f7 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple import weaviate - +from uuid import uuid4 from superagi.vector_store.base import VectorStore from superagi.vector_store.document import Document @@ -53,19 +53,21 @@ def __init__( def add_texts( self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any ) -> List[str]: - result = [] - with self.client.batch as batch: - for i, text in enumerate(texts): - metadata = metadatas[i] if metadatas else {} - data_object = metadata.copy() - data_object[self.text_field] = text - vector = self.embedding_model.get_embedding(text) - - batch.add_data_object(data_object, class_name=self.index, vector=vector) - - object = batch.create_objects()[0] - result.append(object["id"]) - return result + result = {} + collected_ids = [] + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + data_object = metadata.copy() + data_object[self.text_field] = text + vector = self.embedding_model.get_embedding(text) + id = str(uuid4()) + result = {"ids": id, "data_object": data_object, "vectors": vector} + collected_ids.append(id) + try: + self.add_embeddings_to_vector_db(result) + except: + raise Exception("Error adding embeddings to vector db") + return collected_ids def get_matching_text( self, query: str, top_k: int = 5, **kwargs: Any From cf252bfaaddf19a171b283bc0c17d1a7389b141b Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 16:38:51 +0530 Subject: [PATCH 12/34] Fixes for get matching text --- .../knowledge_search/knowledge_search.py | 2 +- superagi/vector_store/weaviate.py | 35 ++++++++----------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/superagi/tools/knowledge_search/knowledge_search.py b/superagi/tools/knowledge_search/knowledge_search.py index 45958bf40..b7ff04715 100644 --- a/superagi/tools/knowledge_search/knowledge_search.py +++ b/superagi/tools/knowledge_search/knowledge_search.py @@ -49,7 +49,7 @@ def _execute(self, query: str): embedding_model = AgentExecutor.get_embedding(model_source, model_api_key) try: if vector_db_index.state == "Custom": - filters = {} + filters = None if vector_db_index.state == "Marketplace": filters = {"knowledge_name": knowledge.name} vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, vector_db_index.name, embedding_model, **db_creds) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 66f1b49f7..414485d02 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -70,19 +70,23 @@ def add_texts( return collected_ids def get_matching_text( - self, query: str, top_k: int = 5, **kwargs: Any + self, query: str, top_k: int = 5, metadata: dict = None, **kwargs: Any ) -> List[Document]: - alpha = kwargs.get("alpha", 0.5) - metadata_fields = self._get_metadata_fields() query_vector = self.embedding_model.get_embedding(query) - - results = ( - self.client.query.get(self.index, metadata_fields + [self.text_field]) - .with_hybrid(query, vector=query_vector, alpha=alpha) - .with_limit(top_k) - .do() - ) - + if metadata is not None: + for key, value in metadata.items(): + filters = { + "path": [key], + "operator": "Equal", + + } + + results = self.client.query.get( + self.index, + [self.text_field], + ).with_near_vector( + {"vector": query_vector, "certainty": 0.7} + ).with_where(filters).with_limit(top_k).do() results_data = results["data"]["Get"][self.index] documents = [] for result in results_data: @@ -95,15 +99,6 @@ def get_matching_text( return documents - def _get_metadata_fields(self) -> List[str]: - schema = self.client.schema.get(self.index) - property_names = [] - for property_schema in schema["properties"]: - property_names.append(property_schema["name"]) - - property_names.remove(self.text_field) - return property_names - def get_index_stats(self) -> dict: result = self.client.query.aggregate(self.index).with_meta_count().do() vector_count = result['data']['Aggregate'][self.index][0]['meta']['count'] From 3c8cb5c9879ae62b5bca77b971720230272ae8ae Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 17:17:22 +0530 Subject: [PATCH 13/34] fixed get_matching_texts --- superagi/vector_store/weaviate.py | 47 +++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 414485d02..47ef62091 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -72,6 +72,7 @@ def add_texts( def get_matching_text( self, query: str, top_k: int = 5, metadata: dict = None, **kwargs: Any ) -> List[Document]: + metadata_fields = self._get_metadata_fields() query_vector = self.embedding_model.get_embedding(query) if metadata is not None: for key, value in metadata.items(): @@ -83,21 +84,25 @@ def get_matching_text( results = self.client.query.get( self.index, - [self.text_field], + metadata_fields + [self.text_field], ).with_near_vector( {"vector": query_vector, "certainty": 0.7} ).with_where(filters).with_limit(top_k).do() + results_data = results["data"]["Get"][self.index] - documents = [] - for result in results_data: - text_content = result[self.text_field] - metadata = {} - for field in metadata_fields: - metadata[field] = result[field] - document = Document(text_content=text_content, metadata=metadata) - documents.append(document) + search_res = self._get_search_res(results_data, query) + documents = self._build_documents(results_data, metadata_fields) - return documents + return {"search_res": search_res, "documents": documents} + + def _get_metadata_fields(self) -> List[str]: + schema = self.client.schema.get(self.index) + property_names = [] + for property_schema in schema["properties"]: + property_names.append(property_schema["name"]) + + property_names.remove(self.text_field) + return property_names def get_index_stats(self) -> dict: result = self.client.query.aggregate(self.index).with_meta_count().do() @@ -121,4 +126,24 @@ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None: class_name = self.index ) except Exception as err: - raise err \ No newline at end of file + raise err + + def _build_documents(self, results_data, metadata_fields) -> List[Document]: + documents = [] + for result in results_data: + text_content = result[self.text_field] + metadata = {} + for field in metadata_fields: + metadata[field] = result[field] + document = Document(text_content=text_content, metadata=metadata) + documents.append(document) + + return documents + + def _get_search_res(self, results, query): + text = [item['text'] for item in results['data']['Get']['Knowledge']] + search_res = f"Query: {text}\n" + for context in text: + search_res += f"Chunk{i}: \n{context}\n" + i += 1 + return search_res \ No newline at end of file From fc6a01885ba904a4f96cd63f07666f0d2455042e Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 17:19:52 +0530 Subject: [PATCH 14/34] changes --- superagi/vector_store/weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 47ef62091..7000a39a3 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -79,7 +79,7 @@ def get_matching_text( filters = { "path": [key], "operator": "Equal", - + "valueString": value } results = self.client.query.get( From f1496d718891c84327d55dd722294c6522900cc4 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 17:51:19 +0530 Subject: [PATCH 15/34] Minor fixes --- superagi/vector_store/weaviate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 7000a39a3..8da757a61 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -141,8 +141,8 @@ def _build_documents(self, results_data, metadata_fields) -> List[Document]: return documents def _get_search_res(self, results, query): - text = [item['text'] for item in results['data']['Get']['Knowledge']] - search_res = f"Query: {text}\n" + text = [item['text'] for item in results] + search_res = f"Query: {query}\n" for context in text: search_res += f"Chunk{i}: \n{context}\n" i += 1 From 0b08f3f10911ee6b82b63e147e26fe519c440d01 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Thu, 10 Aug 2023 17:58:51 +0530 Subject: [PATCH 16/34] Fix for extra variable --- superagi/vector_store/weaviate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index 8da757a61..f6a5a6fd5 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -143,6 +143,7 @@ def _build_documents(self, results_data, metadata_fields) -> List[Document]: def _get_search_res(self, results, query): text = [item['text'] for item in results] search_res = f"Query: {query}\n" + i = 0 for context in text: search_res += f"Chunk{i}: \n{context}\n" i += 1 From 989c62bd4a107adf84d71a8c05cd066bc56d28c4 Mon Sep 17 00:00:00 2001 From: namansleeps Date: Thu, 10 Aug 2023 18:24:40 +0530 Subject: [PATCH 17/34] minor bug --- gui/pages/Content/Knowledge/KnowledgeForm.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui/pages/Content/Knowledge/KnowledgeForm.js b/gui/pages/Content/Knowledge/KnowledgeForm.js index 46f222c7b..7ba785812 100644 --- a/gui/pages/Content/Knowledge/KnowledgeForm.js +++ b/gui/pages/Content/Knowledge/KnowledgeForm.js @@ -34,7 +34,7 @@ export default function KnowledgeForm({ if (data) { setPineconeIndices(data.pinecone || []); setQdrantIndices(data.qdrant || []); - setWeaviateIndices(data.qdrant || []); + setWeaviateIndices(data.weaviate || []); } }) .catch((error) => { From 87dc1984fc02214c36bd4e9c627d82c9001b414b Mon Sep 17 00:00:00 2001 From: Tarraann Date: Fri, 11 Aug 2023 12:01:05 +0530 Subject: [PATCH 18/34] PR Comments resolved --- superagi/vector_store/weaviate.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py index f6a5a6fd5..e9f057851 100644 --- a/superagi/vector_store/weaviate.py +++ b/superagi/vector_store/weaviate.py @@ -42,9 +42,9 @@ def create_weaviate_client( class Weaviate(VectorStore): def __init__( - self, client: weaviate.Client, embedding_model: Any, index: str, text_field: str = "text" + self, client: weaviate.Client, embedding_model: Any, class_name: str, text_field: str = "text" ): - self.index = index + self.class_name = class_name self.embedding_model = embedding_model self.text_field = text_field @@ -63,10 +63,7 @@ def add_texts( id = str(uuid4()) result = {"ids": id, "data_object": data_object, "vectors": vector} collected_ids.append(id) - try: - self.add_embeddings_to_vector_db(result) - except: - raise Exception("Error adding embeddings to vector db") + self.add_embeddings_to_vector_db(result) return collected_ids def get_matching_text( @@ -83,20 +80,20 @@ def get_matching_text( } results = self.client.query.get( - self.index, + self.class_name, metadata_fields + [self.text_field], ).with_near_vector( {"vector": query_vector, "certainty": 0.7} ).with_where(filters).with_limit(top_k).do() - results_data = results["data"]["Get"][self.index] + results_data = results["data"]["Get"][self.class_name] search_res = self._get_search_res(results_data, query) documents = self._build_documents(results_data, metadata_fields) return {"search_res": search_res, "documents": documents} def _get_metadata_fields(self) -> List[str]: - schema = self.client.schema.get(self.index) + schema = self.client.schema.get(self.class_name) property_names = [] for property_schema in schema["properties"]: property_names.append(property_schema["name"]) @@ -105,8 +102,8 @@ def _get_metadata_fields(self) -> List[str]: return property_names def get_index_stats(self) -> dict: - result = self.client.query.aggregate(self.index).with_meta_count().do() - vector_count = result['data']['Aggregate'][self.index][0]['meta']['count'] + result = self.client.query.aggregate(self.class_name).with_meta_count().do() + vector_count = result['data']['Aggregate'][self.class_name][0]['meta']['count'] return {'vector_count': vector_count} def add_embeddings_to_vector_db(self, embeddings: dict) -> None: @@ -114,7 +111,7 @@ def add_embeddings_to_vector_db(self, embeddings: dict) -> None: with self.client.batch as batch: for i in range(len(embeddings['ids'])): data_object = {key: value for key, value in embeddings['data_object'][i].items()} - batch.add_data_object(data_object, class_name=self.index, uuid=embeddings['ids'][i], vector=embeddings['vectors'][i]) + batch.add_data_object(data_object, class_name=self.class_name, uuid=embeddings['ids'][i], vector=embeddings['vectors'][i]) except Exception as err: raise err @@ -123,7 +120,7 @@ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None: for id in ids: self.client.data_object.delete( uuid = id, - class_name = self.index + class_name = self.class_name ) except Exception as err: raise err From 5320b1298c322c0d09ef30adb83dc438314a9363 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Fri, 11 Aug 2023 12:29:20 +0530 Subject: [PATCH 19/34] Tests for weaviate.py --- .../vector_store/test_weaviate.py | 163 +++++++----------- 1 file changed, 67 insertions(+), 96 deletions(-) diff --git a/tests/integration_tests/vector_store/test_weaviate.py b/tests/integration_tests/vector_store/test_weaviate.py index 188555f89..3cdfbf0c1 100644 --- a/tests/integration_tests/vector_store/test_weaviate.py +++ b/tests/integration_tests/vector_store/test_weaviate.py @@ -1,96 +1,67 @@ -import numpy as np -import pytest - -from superagi.vector_store import weaviate -from superagi.vector_store.document import Document -from superagi.vector_store.embedding.openai import OpenAiEmbedding - - -@pytest.fixture -def client(): - client = weaviate.create_weaviate_client(use_embedded=True) - yield client - client.schema.delete_all() - - -@pytest.fixture -def mock_openai_embedding(monkeypatch): - monkeypatch.setattr( - OpenAiEmbedding, - "get_embedding", - lambda self, text: np.random.random(3).tolist(), - ) - - -@pytest.fixture -def store(client, mock_openai_embedding): - client.schema.delete_all() - yield weaviate.Weaviate( - client, OpenAiEmbedding(api_key="test_api_key"), "Test_index", "text" - ) - - -@pytest.fixture -def dataset(): - book_titles = [ - "The Great Gatsby", - "To Kill a Mockingbird", - "1984", - "Pride and Prejudice", - "The Catcher in the Rye", - ] - - documents = [] - for i, title in enumerate(book_titles): - author = f"Author {i}" - description = f"A summary of {title}" - text_content = f"This is the text for {title}" - metadata = {"author": author, "description": description} - document = Document(text_content=text_content, metadata=metadata) - - documents.append(document) - - return documents - - -@pytest.fixture -def dataset_no_metadata(): - book_titles = [ - "The Lord of the Rings", - "The Hobbit", - "The Chronicles of Narnia", - ] - - documents = [] - for title in book_titles: - text_content = f"This is the text for {title}" - document = Document(text_content=text_content) - documents.append(document) - - return documents - - -@pytest.mark.parametrize( - "data, results", - [ - ("dataset", (5, 2)), - ("dataset_no_metadata", (3, 0)), - ], -) -def test_add_texts(store, data, results, request): - dataset = request.getfixturevalue(data) - count, num_metadata = results - ids = store.add_documents(dataset) - metadata_fields = store._get_metadata_fields() - assert len(ids) == count - assert len(metadata_fields) == num_metadata - - # manual cleanup because you will upload to the same index again - store.client.schema.delete_all() - - -def test_get_matching_text(store, dataset): - store.add_documents(dataset) - results = store.get_matching_text("The Great Gatsby", top_k=2) - assert len(results) == 2 - assert results[0] == dataset[0] +import unittest +from unittest.mock import Mock, patch, call, MagicMock +from superagi.vector_store.weaviate import create_weaviate_client, Weaviate, Document + +class TestWeaviateClient(unittest.TestCase): + @patch('weaviate.Client') + @patch('weaviate.AuthApiKey') + def test_create_weaviate_client(self, MockAuth, MockClient): + # Test when url and api_key are provided + auth_instance = MockAuth.return_value + MockClient.return_value = 'client' + self.assertEqual(create_weaviate_client('url', 'api_key'), 'client') + MockAuth.assert_called_once_with(api_key='api_key') + MockClient.assert_called_once_with(url='url', auth_client_secret=auth_instance) + + with self.assertRaises(ValueError): + create_weaviate_client() # Raises an error if no url is provided + +class TestWeaviate(unittest.TestCase): + + def setUp(self): + # create a new mock object for the client.batch attribute with the required methods for a context manager. + mock_batch = MagicMock() + mock_batch.__enter__.return_value = mock_batch + mock_batch.__exit__.return_value = None + + self.client = Mock() + self.client.batch = mock_batch + + self.embedding_model = Mock() + self.weaviateVectorStore = Weaviate(self.client, self.embedding_model, 'class_name', 'text_field') + + def test_get_matching_text(self): + self.client.query.get.return_value.with_near_vector.return_value.with_where.return_value.with_limit.return_value.do.return_value = {'data': {'Get': {'class_name': []}}} + self.embedding_model.get_embedding.return_value = 'vector' + self.weaviateVectorStore._get_metadata_fields = Mock(return_value=['field1', 'field2']) + self.weaviateVectorStore._get_search_res = Mock(return_value='search_res') + self.weaviateVectorStore._build_documents = Mock(return_value=['document1', 'document2']) + self.assertEqual(self.weaviateVectorStore.get_matching_text('query', metadata={'field1': 'value'}) + , {'search_res': 'search_res', 'documents': ['document1', 'document2']}) + self.embedding_model.get_embedding.assert_called_once_with('query') + + def test_add_texts(self): + self.embedding_model.get_embedding.return_value = 'vector' + self.weaviateVectorStore.add_embeddings_to_vector_db = Mock() + texts = ['text1', 'text2'] + result = self.weaviateVectorStore.add_texts(texts) + self.assertEqual(len(result), 2) # We expect to get 2 IDs. + self.assertTrue(isinstance(result[0], str)) # The IDs should be strings. + self.embedding_model.get_embedding.assert_has_calls([call(texts[0]), call(texts[1])]) + self.assertEqual(self.weaviateVectorStore.add_embeddings_to_vector_db.call_count, 2) + + def test_add_embeddings_to_vector_db(self): + embeddings = {'ids': ['id1', 'id2'], 'data_object': [{'field': 'value1'}, {'field': 'value2'}], 'vectors': ['v1', 'v2']} + self.weaviateVectorStore.add_embeddings_to_vector_db(embeddings) + calls = [call.add_data_object({'field': 'value1'}, class_name='class_name', uuid='id1', vector='v1'), + call.add_data_object({'field': 'value2'}, class_name='class_name', uuid='id2', vector='v2')] + + self.client.batch.assert_has_calls(calls) + + def test_delete_embeddings_from_vector_db(self): + # You need to setup appropriate return values from the Weaviate client + self.weaviateVectorStore.delete_embeddings_from_vector_db(['id1', 'id2']) + self.client.data_object.delete.assert_called() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 80bcbcf5884cc46291ef3d13faa4286cf0c9ecdc Mon Sep 17 00:00:00 2001 From: Tarraann Date: Fri, 11 Aug 2023 12:42:19 +0530 Subject: [PATCH 20/34] Test for vector embedding weaaviate --- .../vector_embeddings/test_weaviate.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/integration_tests/vector_embeddings/test_weaviate.py diff --git a/tests/integration_tests/vector_embeddings/test_weaviate.py b/tests/integration_tests/vector_embeddings/test_weaviate.py new file mode 100644 index 000000000..b05d763e3 --- /dev/null +++ b/tests/integration_tests/vector_embeddings/test_weaviate.py @@ -0,0 +1,25 @@ +import unittest +from superagi.vector_embeddings.base import VectorEmbeddings +from superagi.vector_embeddings.weaviate import Weaviate + +class TestWeaviate(unittest.TestCase): + + def setUp(self): + self.weaviate = Weaviate(uuid="1234", embeds=[0.1, 0.2, 0.3, 0.4], metadata={"info": "sample data"}) + + def test_init(self): + self.assertEqual(self.weaviate.uuid, "1234") + self.assertEqual(self.weaviate.embeds, [0.1, 0.2, 0.3, 0.4]) + self.assertEqual(self.weaviate.metadata, {"info": "sample data"}) + + def test_get_vector_embeddings_from_chunks(self): + expected_result = { + "ids": "1234", + "data_object": {"info": "sample data"}, + "vectors": [0.1, 0.2, 0.3, 0.4] + } + self.assertEqual(self.weaviate.get_vector_embeddings_from_chunks(), expected_result) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 268ed3db385de0b6e122ad8b076baa4a7e211ede Mon Sep 17 00:00:00 2001 From: Tarraann Date: Fri, 11 Aug 2023 12:45:09 +0530 Subject: [PATCH 21/34] Tests for vector embedding factory --- .../test_vector_embedding_factory.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py b/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py index de447c9c6..c9d35589e 100644 --- a/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py +++ b/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py @@ -6,23 +6,30 @@ class TestVectorEmbeddingFactory(unittest.TestCase): @patch("superagi.vector_embeddings.pinecone.Pinecone.__init__", return_value=None) @patch("superagi.vector_embeddings.qdrant.Qdrant.__init__", return_value=None) - def test_build_vector_storge(self, mock_qdrant, mock_pinecone): + @patch("superagi.vector_embeddings.weaviate.Weaviate.__init__", return_value=None) + def test_build_vector_storage(self, mock_weaviate, mock_qdrant, mock_pinecone): test_data = { "1": {"id": 1, "embeds": [1,2,3], "text": "test", "chunk": "chunk", "knowledge_name": "knowledge"}, "2": {"id": 2, "embeds": [4,5,6], "text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"}, } - vector_storge = VectorEmbeddingFactory.build_vector_storage('Pinecone', test_data) + vector_storage = VectorEmbeddingFactory.build_vector_storage('Pinecone', test_data) mock_pinecone.assert_called_once_with( [1,2], [[1,2,3],[4,5,6]], [{"text": "test", "chunk": "chunk", "knowledge_name": "knowledge"}, {"text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"}] ) - vector_storge = VectorEmbeddingFactory.build_vector_storage('Qdrant', test_data) + vector_storage = VectorEmbeddingFactory.build_vector_storage('Qdrant', test_data) mock_qdrant.assert_called_once_with( [1,2], [[1,2,3],[4,5,6]], [{"text": "test", "chunk": "chunk", "knowledge_name": "knowledge"}, {"text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"}] ) + vector_storage = VectorEmbeddingFactory.build_vector_storage('Weaviate', test_data) + + mock_weaviate.assert_called_once_with( + [1,2], [[1,2,3],[4,5,6]], [{"text": "test", "chunk": "chunk", "knowledge_name": "knowledge"}, {"text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"}] + ) + if __name__ == "__main__": unittest.main() \ No newline at end of file From 0c77bf154c28f8c7c7362477b401782ad6222653 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Fri, 11 Aug 2023 19:10:47 +0530 Subject: [PATCH 22/34] Schedule Agent Fix (#1043) --- superagi/models/agent_config.py | 11 ++++++----- .../controllers/test_update_agent_config_table.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/superagi/models/agent_config.py b/superagi/models/agent_config.py index f7af6de01..640b922e0 100644 --- a/superagi/models/agent_config.py +++ b/superagi/models/agent_config.py @@ -51,12 +51,12 @@ def update_agent_configurations_table(cls, session, agent_id: Union[int, None], ).first() if agent_toolkits_config: - agent_toolkits_config.value = updated_details_dict['toolkits'] + agent_toolkits_config.value = str(updated_details_dict['toolkits']) else: agent_toolkits_config = AgentConfiguration( agent_id=agent_id, key='toolkits', - value=updated_details_dict['toolkits'] + value=str(updated_details_dict['toolkits']) ) session.add(agent_toolkits_config) @@ -67,12 +67,12 @@ def update_agent_configurations_table(cls, session, agent_id: Union[int, None], ).first() if knowledge_config: - knowledge_config.value = updated_details_dict['knowledge'] + knowledge_config.value = str(updated_details_dict['knowledge']) else: knowledge_config = AgentConfiguration( agent_id=agent_id, key='knowledge', - value=updated_details_dict['knowledge'] + value=str(updated_details_dict['knowledge']) ) session.add(knowledge_config) @@ -80,12 +80,13 @@ def update_agent_configurations_table(cls, session, agent_id: Union[int, None], agent_configs = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all() for agent_config in agent_configs: if agent_config.key in updated_details_dict: - agent_config.value = updated_details_dict[agent_config.key] + agent_config.value = str(updated_details_dict[agent_config.key]) # Commit the changes to the database session.commit() return "Details updated successfully" + @classmethod def get_model_api_key(cls, session, agent_id: int, model: str): """ diff --git a/tests/unit_tests/controllers/test_update_agent_config_table.py b/tests/unit_tests/controllers/test_update_agent_config_table.py index 53297e7ad..632f59083 100644 --- a/tests/unit_tests/controllers/test_update_agent_config_table.py +++ b/tests/unit_tests/controllers/test_update_agent_config_table.py @@ -26,6 +26,6 @@ def test_update_existing_toolkits(): result = AgentConfiguration.update_agent_configurations_table(mock_session, agent_id, updated_details) #Check whether the value gets updated or not - assert existing_toolkits_config.value == [1, 2] + assert existing_toolkits_config.value == '[1, 2]' assert mock_session.commit.called_once() assert result == "Details updated successfully" From e90b59b4993db1ccf449ba39ad9e1eeb593d9ccd Mon Sep 17 00:00:00 2001 From: Fluder-Paradyne <121793617+Fluder-Paradyne@users.noreply.github.com> Date: Fri, 11 Aug 2023 20:00:15 +0530 Subject: [PATCH 23/34] add small check (#1046) --- superagi/controllers/agent_execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index ed6506a1f..743aae05b 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -102,7 +102,7 @@ def create_agent_execution(agent_execution: AgentExecutionIn, if agent_config.key not in keys_to_exclude: if agent_config.key == "toolkits": if agent_config.value: - toolkits = [int(item) for item in agent_config.value.strip('{}').split(',') if item.strip()] + toolkits = [int(item) for item in agent_config.value.strip('{}').split(',') if item.strip() and item != '[]'] agent_execution_configs[agent_config.key] = toolkits else: agent_execution_configs[agent_config.key] = [] From cf869a2ecdbdf88fef43b4f16657e59a020f046f Mon Sep 17 00:00:00 2001 From: namansleeps <122260931+namansleeps@users.noreply.github.com> Date: Sat, 12 Aug 2023 15:20:07 +0530 Subject: [PATCH 24/34] detail bug in schedule agent and naming of new run in edit agent (#1049) --- gui/pages/Content/Agents/ActivityFeed.js | 10 +++------- gui/pages/Content/Agents/AgentCreate.js | 4 +++- gui/pages/Content/Agents/AgentSchedule.js | 4 +++- gui/pages/Content/Agents/AgentWorkspace.js | 1 + gui/pages/Content/Agents/Details.js | 8 ++++++-- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/gui/pages/Content/Agents/ActivityFeed.js b/gui/pages/Content/Agents/ActivityFeed.js index 25258c7db..4dbfb24f5 100644 --- a/gui/pages/Content/Agents/ActivityFeed.js +++ b/gui/pages/Content/Agents/ActivityFeed.js @@ -165,12 +165,8 @@ export default function ActivityFeed({selectedRunId, selectedView, setFetchedDat } } - {feeds.length < 1 && !agent?.is_running && !agent?.is_scheduled ? - (isLoading ? -
- -
- :
The Agent is not scheduled
) : null + }}>The Agent is not scheduled } {feedContainerRef.current && feedContainerRef.current.scrollTop >= 1200 && diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index e54dcc072..1fb1f5d48 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -534,7 +534,9 @@ export default function AgentCreate({ setEditButtonClicked(true); agentData.agent_id = editAgentId; const name = agentData.name - agentData.name = "New Run" + const adjustedDate = new Date((new Date()).getTime() + 6*24*60*60*1000 - 1*60*1000); + const formattedDate = `${adjustedDate.getDate()} ${['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December'][adjustedDate.getMonth()]} ${adjustedDate.getFullYear()} ${adjustedDate.getHours().toString().padStart(2, '0')}:${adjustedDate.getMinutes().toString().padStart(2, '0')}`; + agentData.name = "Run " + formattedDate addAgentRun(agentData) .then((response) => { if(response){ diff --git a/gui/pages/Content/Agents/AgentSchedule.js b/gui/pages/Content/Agents/AgentSchedule.js index 4c3f5e31d..76cd10003 100644 --- a/gui/pages/Content/Agents/AgentSchedule.js +++ b/gui/pages/Content/Agents/AgentSchedule.js @@ -166,8 +166,10 @@ export default function AgentSchedule({ const {schedule_id} = response.data; toast.success('Scheduled successfully!', {autoClose: 1800}); setCreateModal(); - EventBus.emit('refreshDate', {}); EventBus.emit('reFetchAgents', {}); + setTimeout(() => { + EventBus.emit('refreshDate', {}); + }, 1000) }) .catch(error => { console.error('Error:', error); diff --git a/gui/pages/Content/Agents/AgentWorkspace.js b/gui/pages/Content/Agents/AgentWorkspace.js index 2e3be1f9e..898bf6efb 100644 --- a/gui/pages/Content/Agents/AgentWorkspace.js +++ b/gui/pages/Content/Agents/AgentWorkspace.js @@ -74,6 +74,7 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a toast.success('Schedule stopped successfully!', {autoClose: 1800}); setCreateStopModal(false); EventBus.emit('reFetchAgents', {}); + setAgentScheduleDetails(null) } }) .catch((error) => { diff --git a/gui/pages/Content/Agents/Details.js b/gui/pages/Content/Agents/Details.js index 20489787f..a3c05dc17 100644 --- a/gui/pages/Content/Agents/Details.js +++ b/gui/pages/Content/Agents/Details.js @@ -9,7 +9,7 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails, const [showConstraints, setShowConstraints] = useState(false); const [showInstructions, setShowInstructions] = useState(false); const [filteredInstructions, setFilteredInstructions] = useState([]); - const [scheduleText, setScheduleText] = useState('Agent is not Scheduled'); + const [scheduleText, setScheduleText] = useState(''); const [agentDetails, setAgentDetails] = useState(null) const info_text = { marginLeft: '7px', @@ -37,6 +37,10 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails, }, [agentDetails1]); useEffect(() => { + if(!agentScheduleDetails){ + setScheduleText('') + return + } if (agent?.is_scheduled) { if (agentScheduleDetails?.recurrence_interval !== null) { if ((agentScheduleDetails?.expiry_runs === -1 || agentScheduleDetails?.expiry_runs == null) && agentScheduleDetails?.expiry_date !== null) { @@ -191,7 +195,7 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails,
info-icon
Stop after {agentDetails.max_iterations} iterations
} - {agent?.is_scheduled &&
+ {agent?.is_scheduled && scheduleText &&
info-icon
{scheduleText}
} From 9b8592b3fa497c66ffd38a977384aa0179d5e798 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 16 Aug 2023 13:15:33 +0530 Subject: [PATCH 25/34] PR Comments resolved --- superagi/controllers/vector_dbs.py | 3 ++- superagi/models/vector_db_indices.py | 2 +- superagi/vector_embeddings/weaviate.py | 6 +----- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/superagi/controllers/vector_dbs.py b/superagi/controllers/vector_dbs.py index 38c337a4b..49cd61b2d 100644 --- a/superagi/controllers/vector_dbs.py +++ b/superagi/controllers/vector_dbs.py @@ -139,9 +139,10 @@ def update_vector_db(new_indices: list, vector_db_id: int): vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, index, **db_creds) vector_db_index_stats = vector_db_storage.get_index_stats() index_state = "Custom" if vector_db_index_stats["vector_count"] > 0 else "None" + dimensions = vector_db_index_stats["dimensions"] if 'dimensions' in vector_db_index_stats else None except: raise HTTPException(status_code=400, detail="Unable to update vector db") - VectordbIndices.add_vector_index(db.session, index, vector_db_id, index_state, vector_db_index_stats["dimensions"]) + VectordbIndices.add_vector_index(db.session, index, vector_db_id, index_state, dimensions) diff --git a/superagi/models/vector_db_indices.py b/superagi/models/vector_db_indices.py index eaba8b669..f6c57fa8a 100644 --- a/superagi/models/vector_db_indices.py +++ b/superagi/models/vector_db_indices.py @@ -47,7 +47,7 @@ def delete_vector_db_index(cls, session, vector_index_id): session.commit() @classmethod - def add_vector_index(cls, session, index_name, vector_db_id, state, dimensions = None): + def add_vector_index(cls, session, index_name, vector_db_id, state, dimensions = None): #will be none only in the case of weaviate vector_index = VectordbIndices(name=index_name, vector_db_id=vector_db_id, dimensions=dimensions, state=state) session.add(vector_index) session.commit() diff --git a/superagi/vector_embeddings/weaviate.py b/superagi/vector_embeddings/weaviate.py index 9d77bd15b..13f61d719 100644 --- a/superagi/vector_embeddings/weaviate.py +++ b/superagi/vector_embeddings/weaviate.py @@ -10,9 +10,5 @@ def __init__(self, uuid, embeds, metadata): def get_vector_embeddings_from_chunks(self): """ Returns embeddings for vector dbs from final chunks""" - result = {} - result['ids'] = self.uuid - result['data_object'] = self.metadata - result['vectors'] = self.embeds - return result \ No newline at end of file + return {'id': self.uuid, 'data_object': self.metadata, 'vectors': self.embeds} \ No newline at end of file From 77455cec5c93a9c171f9c7aad6c45de163a71224 Mon Sep 17 00:00:00 2001 From: lalit-contlo <138583454+lalit-contlo@users.noreply.github.com> Date: Wed, 16 Aug 2023 14:24:58 +0530 Subject: [PATCH 26/34] External tools, marketplace tools fix (#1058) --- superagi/agent/tool_builder.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/superagi/agent/tool_builder.py b/superagi/agent/tool_builder.py index 33a0395de..1801047a1 100644 --- a/superagi/agent/tool_builder.py +++ b/superagi/agent/tool_builder.py @@ -1,5 +1,5 @@ import importlib - +import os from superagi.config.config import get_config from superagi.llms.llm_model_factory import get_model from superagi.models.tool import Tool @@ -58,9 +58,11 @@ def build_tool(self, tool: Tool): """ file_name = self.__validate_filename(filename=tool.file_name) - tools_dir = get_config("TOOLS_DIR") - if tools_dir is None: - tools_dir = "superagi/tools" + tool_paths = ["superagi/tools", "superagi/tools/external_tools", "superagi/tools/marketplace_tools"] + for tool_path in tool_paths: + if os.path.exists(os.path.join(os.getcwd(), tool_path) + '/' + tool.folder_name): + tools_dir = tool_path + break parsed_tools_dir = tools_dir.rstrip("/") module_name = ".".join(parsed_tools_dir.split("/") + [tool.folder_name, file_name]) From 4c818453854ec5a494f38394296e1468464d4eb3 Mon Sep 17 00:00:00 2001 From: Tarraann Date: Wed, 16 Aug 2023 18:07:10 +0530 Subject: [PATCH 27/34] MInor fix --- superagi/vector_embeddings/weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superagi/vector_embeddings/weaviate.py b/superagi/vector_embeddings/weaviate.py index 13f61d719..80d8f86b6 100644 --- a/superagi/vector_embeddings/weaviate.py +++ b/superagi/vector_embeddings/weaviate.py @@ -11,4 +11,4 @@ def __init__(self, uuid, embeds, metadata): def get_vector_embeddings_from_chunks(self): """ Returns embeddings for vector dbs from final chunks""" - return {'id': self.uuid, 'data_object': self.metadata, 'vectors': self.embeds} \ No newline at end of file + return {'ids': self.uuid, 'data_object': self.metadata, 'vectors': self.embeds} \ No newline at end of file From e14b118aa5309a75449117f211bf96c98f4a9606 Mon Sep 17 00:00:00 2001 From: namansleeps <122260931+namansleeps@users.noreply.github.com> Date: Wed, 16 Aug 2023 18:40:07 +0530 Subject: [PATCH 28/34] bug fixes of weaviate and profile dropdown bug (#1062) --- docker-compose-dev.yaml | 8 ++++---- gui/pages/Content/Agents/AgentCreate.js | 2 +- gui/pages/Dashboard/Settings/AddDatabase.js | 2 +- gui/pages/Dashboard/TopBar.js | 4 ++-- gui/pages/_app.css | 6 ++++++ 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml index 94044916b..d3caf6029 100644 --- a/docker-compose-dev.yaml +++ b/docker-compose-dev.yaml @@ -28,10 +28,10 @@ services: NEXT_PUBLIC_API_BASE_URL: "/api" networks: - super_network -# volumes: -# - ./gui:/app -# - /app/node_modules/ -# - /app/.next/ + volumes: + - ./gui:/app + - /app/node_modules/ + - /app/.next/ super__redis: image: "redis/redis-stack-server:latest" networks: diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 1fb1f5d48..3f7f62b30 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -534,7 +534,7 @@ export default function AgentCreate({ setEditButtonClicked(true); agentData.agent_id = editAgentId; const name = agentData.name - const adjustedDate = new Date((new Date()).getTime() + 6*24*60*60*1000 - 1*60*1000); + const adjustedDate = new Date((new Date()).getTime()); const formattedDate = `${adjustedDate.getDate()} ${['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December'][adjustedDate.getMonth()]} ${adjustedDate.getFullYear()} ${adjustedDate.getHours().toString().padStart(2, '0')}:${adjustedDate.getMinutes().toString().padStart(2, '0')}`; agentData.name = "Run " + formattedDate addAgentRun(agentData) diff --git a/gui/pages/Dashboard/Settings/AddDatabase.js b/gui/pages/Dashboard/Settings/AddDatabase.js index 7e8a19c9e..bc5efdc67 100644 --- a/gui/pages/Dashboard/Settings/AddDatabase.js +++ b/gui/pages/Dashboard/Settings/AddDatabase.js @@ -332,7 +332,7 @@ export default function AddDatabase({internalId, sendDatabaseDetailsData}) {
-
+
{selectedDB === 'Weaviate' ? : }
{collections.map((collection, index) => (
{dropdown && env === 'PROD' && -
setDropdown(true)} +
setDropdown(true)} onMouseLeave={() => setDropdown(false)}>
    -
  • setDropdown(false)}>{userName}
  • + {userName &&
  • setDropdown(false)}>{userName}
  • }
  • Logout
} diff --git a/gui/pages/_app.css b/gui/pages/_app.css index c78b07085..8c6a78a34 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -1673,3 +1673,9 @@ tr{ .history_box_selected{ background: #474255; } + +.top_bar_profile_dropdown{ + display: flex; + flex-direction: row; + justify-content: center; +} From 240cccbeba0fd15b8ec594cd09cf6760b6f0a9dc Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Thu, 17 Aug 2023 10:55:15 +0530 Subject: [PATCH 29/34] New run fix (#1055) * Schedule agent fix * Schedule agent fix * Update agent template fix * new run fix --------- Co-authored-by: Rounak Bhatia --- superagi/controllers/agent_execution.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index 743aae05b..1a485562c 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -108,8 +108,7 @@ def create_agent_execution(agent_execution: AgentExecutionIn, agent_execution_configs[agent_config.key] = [] elif agent_config.key == "constraints": if agent_config.value: - constraints = [item.strip('"') for item in agent_config.value.strip('{}').split(',')] - agent_execution_configs[agent_config.key] = constraints + agent_execution_configs[agent_config.key] = agent_config.value else: agent_execution_configs[agent_config.key] = [] else: From a9a8392eebf2064a5f2bc83bf6351a93902b274d Mon Sep 17 00:00:00 2001 From: lalit-contlo <138583454+lalit-contlo@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:07:04 +0530 Subject: [PATCH 30/34] Test tool builder fix (#1067) --- superagi/agent/tool_builder.py | 1 + tests/unit_tests/agent/test_tool_builder.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/superagi/agent/tool_builder.py b/superagi/agent/tool_builder.py index 1801047a1..eb2194468 100644 --- a/superagi/agent/tool_builder.py +++ b/superagi/agent/tool_builder.py @@ -58,6 +58,7 @@ def build_tool(self, tool: Tool): """ file_name = self.__validate_filename(filename=tool.file_name) + tools_dir="" tool_paths = ["superagi/tools", "superagi/tools/external_tools", "superagi/tools/marketplace_tools"] for tool_path in tool_paths: if os.path.exists(os.path.join(os.getcwd(), tool_path) + '/' + tool.folder_name): diff --git a/tests/unit_tests/agent/test_tool_builder.py b/tests/unit_tests/agent/test_tool_builder.py index d5be7fc78..874fcc3f9 100644 --- a/tests/unit_tests/agent/test_tool_builder.py +++ b/tests/unit_tests/agent/test_tool_builder.py @@ -43,7 +43,7 @@ def test_build_tool(mock_getattr, mock_import_module, tool_builder, tool): result_tool = tool_builder.build_tool(tool) - mock_import_module.assert_called_with('superagi.tools.test_folder.test') + mock_import_module.assert_called_with('.test_folder.test') mock_getattr.assert_called_with(mock_module, tool.class_name) assert result_tool.toolkit_config.session == tool_builder.session From f8e9127b6fe97058e87adfc9078d5c409d54ebac Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Thu, 17 Aug 2023 15:21:50 +0530 Subject: [PATCH 31/34] removing apollo from superagi workflow --- superagi/agent/workflow_seed.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/superagi/agent/workflow_seed.py b/superagi/agent/workflow_seed.py index fb9c4b03e..0bb4694c9 100644 --- a/superagi/agent/workflow_seed.py +++ b/superagi/agent/workflow_seed.py @@ -31,21 +31,22 @@ class AgentWorkflowSeed: def build_sales_workflow(cls, session): agent_workflow = AgentWorkflow.find_or_create_by_name(session, "Sales Engagement Workflow", "Sales Engagement Workflow") - step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id, - str(agent_workflow.id) + "_step1", - ApolloSearchTool().name, - "Search for leads based on the given goals", - step_type="TRIGGER") - - step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id, - str(agent_workflow.id) + "_step2", - WriteFileTool().name, - "Write the leads to a csv file") + # step1 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id, + # str(agent_workflow.id) + "_step1", + # ApolloSearchTool().name, + # "Search for leads based on the given goals", + # step_type="TRIGGER") + # + # step2 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id, + # str(agent_workflow.id) + "_step2", + # WriteFileTool().name, + # "Write the leads to a csv file") step3 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id, str(agent_workflow.id) + "_step3", ReadFileTool().name, - "Read the leads from the file generated in the previous run") + "Read the leads from the file generated in the previous run", + step_type="TRIGGER") # task queue ends when the elements gets over step4 = AgentWorkflowStep.find_or_create_tool_workflow_step(session, agent_workflow.id, From 96be7ed59da746e97d1ecb376460de051cd5d9ed Mon Sep 17 00:00:00 2001 From: TransformerOptimus Date: Thu, 17 Aug 2023 15:28:41 +0530 Subject: [PATCH 32/34] removing apollo from superagi workflow --- superagi/agent/workflow_seed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superagi/agent/workflow_seed.py b/superagi/agent/workflow_seed.py index 0bb4694c9..b888c6547 100644 --- a/superagi/agent/workflow_seed.py +++ b/superagi/agent/workflow_seed.py @@ -75,8 +75,8 @@ def build_sales_workflow(cls, session): SendEmailTool().name, "Customize the Email according to the company information in the mail") - AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id) - AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id) + # AgentWorkflowStep.add_next_workflow_step(session, step1.id, step2.id) + # AgentWorkflowStep.add_next_workflow_step(session, step2.id, step3.id) AgentWorkflowStep.add_next_workflow_step(session, step3.id, step4.id) AgentWorkflowStep.add_next_workflow_step(session, step4.id, -1, "COMPLETE") AgentWorkflowStep.add_next_workflow_step(session, step4.id, step5.id) From 846d5f523dd312f46c7243307bc8be627558eb96 Mon Sep 17 00:00:00 2001 From: Maverick-F35 <138012351+Maverick-F35@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:31:47 +0530 Subject: [PATCH 33/34] Super agi api (#992) --- .github/workflows/ci.yml | 1 + gui/pages/Content/Agents/Agents.module.css | 31 ++ .../Content/Marketplace/Market.module.css | 14 + gui/pages/Dashboard/Settings/ApiKeys.js | 286 +++++++++++++++ gui/pages/Dashboard/Settings/Settings.js | 8 + gui/pages/_app.css | 25 ++ gui/pages/api/DashboardService.js | 17 + gui/public/images/copy_icon.svg | 3 + gui/public/images/key_white.svg | 3 + main.py | 9 +- .../446884dcae58_add_api_key_and_web_hook.py | 65 ++++ superagi/controllers/api/agent.py | 326 ++++++++++++++++++ superagi/controllers/api_key.py | 55 +++ .../controllers/types/agent_with_config.py | 44 ++- superagi/controllers/webhook.py | 60 ++++ superagi/helper/auth.py | 39 ++- superagi/helper/s3_helper.py | 28 +- superagi/helper/webhook_manager.py | 37 ++ superagi/models/agent.py | 22 +- superagi/models/agent_execution.py | 47 ++- superagi/models/agent_schedule.py | 25 +- superagi/models/api_key.py | 46 +++ superagi/models/project.py | 10 + superagi/models/resource.py | 7 +- superagi/models/toolkit.py | 23 ++ superagi/models/webhook_events.py | 25 ++ superagi/models/webhooks.py | 22 ++ superagi/worker.py | 21 +- tests/unit_tests/controllers/api/__init__.py | 0 .../unit_tests/controllers/api/test_agent.py | 220 ++++++++++++ tests/unit_tests/models/test_agent.py | 21 ++ .../unit_tests/models/test_agent_execution.py | 27 +- .../unit_tests/models/test_agent_schedule.py | 23 ++ tests/unit_tests/models/test_api_key.py | 93 +++++ tests/unit_tests/models/test_project.py | 42 +++ tests/unit_tests/models/test_toolkit.py | 23 +- 36 files changed, 1725 insertions(+), 23 deletions(-) create mode 100644 gui/pages/Dashboard/Settings/ApiKeys.js create mode 100644 gui/public/images/copy_icon.svg create mode 100644 gui/public/images/key_white.svg create mode 100644 migrations/versions/446884dcae58_add_api_key_and_web_hook.py create mode 100644 superagi/controllers/api/agent.py create mode 100644 superagi/controllers/api_key.py create mode 100644 superagi/controllers/webhook.py create mode 100644 superagi/helper/webhook_manager.py create mode 100644 superagi/models/api_key.py create mode 100644 superagi/models/webhook_events.py create mode 100644 superagi/models/webhooks.py create mode 100644 tests/unit_tests/controllers/api/__init__.py create mode 100644 tests/unit_tests/controllers/api/test_agent.py create mode 100644 tests/unit_tests/models/test_agent_schedule.py create mode 100644 tests/unit_tests/models/test_api_key.py create mode 100644 tests/unit_tests/models/test_project.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b94860d37..8bbd7a241 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,6 +91,7 @@ jobs: ENV: DEV PLAIN_OUTPUT: True REDIS_URL: "localhost:6379" + IS_TESTING: True - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/gui/pages/Content/Agents/Agents.module.css b/gui/pages/Content/Agents/Agents.module.css index 62d782ef9..c6553a8f9 100644 --- a/gui/pages/Content/Agents/Agents.module.css +++ b/gui/pages/Content/Agents/Agents.module.css @@ -429,4 +429,35 @@ color: #888888 !important; text-decoration: line-through; pointerEvents: none !important; +} + +.modal_buttons{ + display: flex; + justify-content: flex-end; + margin-top: 20px +} + +.modal_info_class{ + margin-left: -5px; + margin-right: 5px; +} + +.table_contents{ + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + margin-top: 40px; + width: 100% +} + +.create_settings_button{ + display: flex; + justify-content: center; + align-items: center; + margin-top: 10px +} + +.button_margin{ + margin-top: -10px; } \ No newline at end of file diff --git a/gui/pages/Content/Marketplace/Market.module.css b/gui/pages/Content/Marketplace/Market.module.css index 545146d52..4bd162ade 100644 --- a/gui/pages/Content/Marketplace/Market.module.css +++ b/gui/pages/Content/Marketplace/Market.module.css @@ -515,3 +515,17 @@ overflow-y:scroll; overflow-x:hidden; } + +.settings_tab_button_clicked{ + background: #454254; + padding-right: 15px +} + +.settings_tab_button{ + background: transparent; + padding-right: 15px +} + +.settings_tab_img{ + margin-top: -1px; +} diff --git a/gui/pages/Dashboard/Settings/ApiKeys.js b/gui/pages/Dashboard/Settings/ApiKeys.js new file mode 100644 index 000000000..61654b4ca --- /dev/null +++ b/gui/pages/Dashboard/Settings/ApiKeys.js @@ -0,0 +1,286 @@ +import React, {useState, useEffect, useRef} from 'react'; +import {ToastContainer, toast} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import agentStyles from "@/pages/Content/Agents/Agents.module.css"; +import { + createApiKey, deleteApiKey, + editApiKey, getApiKeys, +} from "@/pages/api/DashboardService"; +import {EventBus} from "@/utils/eventBus"; +import {createInternalId, loadingTextEffect, preventDefault, removeTab, returnToolkitIcon} from "@/utils/utils"; +import Image from "next/image"; +import styles from "@/pages/Content/Marketplace/Market.module.css"; +import styles1 from "@/pages/Content/Knowledge/Knowledge.module.css"; + +export default function ApiKeys() { + const [apiKeys, setApiKeys] = useState([]); + const [keyName, setKeyName] = useState(''); + const [editKey, setEditKey] = useState(''); + const apiKeyRef = useRef(null); + const editKeyRef = useRef(null); + const [editKeyId, setEditKeyId] = useState(-1); + const [deleteKey, setDeleteKey] = useState('') + const [isLoading, setIsLoading] = useState(true) + const [activeDropdown, setActiveDropdown] = useState(null); + const [editModal, setEditModal] = useState(false); + const [deleteKeyId, setDeleteKeyId] = useState(-1); + const [deleteModal, setDeleteModal] = useState(false); + const [createModal, setCreateModal] = useState(false); + const [displayModal, setDisplayModal] = useState(false); + const [apiKeyGenerated, setApiKeyGenerated] = useState(''); + const [loadingText, setLoadingText] = useState("Loading Api Keys"); + + + + useEffect(() => { + loadingTextEffect('Loading Api Keys', setLoadingText, 500); + fetchApiKeys() + }, []); + + + const handleModelApiKey = (event) => { + setKeyName(event.target.value); + }; + + const handleEditApiKey = (event) => { + setEditKey(event.target.value); + }; + + const createApikey = () => { + if(!keyName){ + toast.error("Enter key name", {autoClose: 1800}); + return; + } + createApiKey({name : keyName}) + .then((response) => { + setApiKeyGenerated(response.data.api_key) + toast.success("Api Key Generated", {autoClose: 1800}); + setCreateModal(false); + setDisplayModal(true); + fetchApiKeys(); + }) + .catch((error) => { + console.error('Error creating api key', error); + }); + } + const handleCopyClick = async () => { + if (apiKeyRef.current) { + try { + await navigator.clipboard.writeText(apiKeyRef.current.value); + toast.success("Key Copied", {autoClose: 1800}); + } catch (err) { + toast.error('Failed to Copy', {autoClose: 1800}); + } + } + }; + + const fetchApiKeys = () => { + getApiKeys() + .then((response) => { + const formattedData = response.data.map(item => { + return { + ...item, + created_at: `${new Date(item.created_at).getDate()}-${["JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC"][new Date(item.created_at).getMonth()]}-${new Date(item.created_at).getFullYear()}` + }; + }); + setApiKeys(formattedData) + setIsLoading(false) + }) + .catch((error) => { + console.error('Error fetching Api Keys', error); + }); + } + + const handleEditClick = () => { + if(editKeyRef.current.value.length <1){ + toast.error("Enter valid key name", {autoClose: 1800}); + return; + } + editApiKey({id: editKeyId,name : editKey}) + .then((response) => { + toast.success("Api Key Edited", {autoClose: 1800}); + fetchApiKeys(); + setEditModal(false); + setEditKey('') + setEditKeyId(-1) + }) + .catch((error) => { + console.error('Error editing api key', error); + }); + } + + const handleDeleteClick = () => { + deleteApiKey(deleteKeyId) + .then((response) => { + toast.success("Api Key Deleted", {autoClose: 1800}); + fetchApiKeys(); + setDeleteModal(false); + setDeleteKeyId(-1) + setDeleteKey('') + }) + .catch((error) => { + toast.error("Error deleting api key", {autoClose: 1800}); + console.error('Error deleting api key', error); + }); + } + + return (<> +
+
+
+ {!isLoading ?
+
+
API Keys
+ {apiKeys && apiKeys.length > 0 && !isLoading && + } +
+
+ + + + {apiKeys.length < 1 &&
+ no-permissions + No API Keys created! +
+ +
+
} + + {apiKeys.length > 0 &&
+ + + + + + + + + +
NameKeyCreated Date
+
+ + + {apiKeys.map((item, index) => ( + + + + + + + ))} + +
{item.name}{item.key.slice(0, 2) + "****" + item.key.slice(-4)}{item.created_at} setActiveDropdown(null)} onClick={() => { + if (activeDropdown === index) { + setActiveDropdown(null); + } else { + setActiveDropdown(index); + } + }}> + run-icon +
setActiveDropdown(null)}> +
    +
  • {setEditKey(item.name); setEditKeyId(item.id); setEditModal(true); setActiveDropdown(null);}}>Edit
  • +
  • {setDeleteKeyId(item.id); setDeleteKey(item.name) ; setDeleteModal(true); setActiveDropdown(null);}}>Delete
  • +
+
+
} +
+
:
+
{loadingText}
+
} +
+
+
+ + {createModal && (
setCreateModal(false)}> +
+
Create new API Key
+
+ + +
+
+ + +
+
+
)} + + {displayModal && apiKeyGenerated && (
setDisplayModal(false)}> +
+
{keyName} is created
+
+
+
+
+ info-icon +
+
+ Your secret API keys are sensitive pieces of information that should be kept confidential. Do not share them with anyone, and do not expose them in any way. If your secret API keys are compromised, someone could use them to access your API and make unauthorized changes to your data. This secret key is only displayed once for security reasons. Please save it in a secure location where you can access it easily. +
+
+
+
+
+
+
+
+ +
+
+
+
+
+ +
+
+
)} + + {editModal && (
{setEditModal(false); setEditKey(''); setEditKeyId(-1)}}> +
+
Edit API Key
+
+ + +
+
+ + +
+
+
)} + + {deleteModal && (
{setDeleteModal(false); setDeleteKeyId(-1); setDeleteKey('')}}> +
+
Delete {deleteKey} Key
+
+ +
+
+ + +
+
+
)} + + ) +} \ No newline at end of file diff --git a/gui/pages/Dashboard/Settings/Settings.js b/gui/pages/Dashboard/Settings/Settings.js index 297e6a3d5..402fc6050 100644 --- a/gui/pages/Dashboard/Settings/Settings.js +++ b/gui/pages/Dashboard/Settings/Settings.js @@ -4,6 +4,7 @@ import styles from "@/pages/Content/Marketplace/Market.module.css"; import Image from "next/image"; import Model from "@/pages/Dashboard/Settings/Model"; import Database from "@/pages/Dashboard/Settings/Database"; +import ApiKeys from "@/pages/Dashboard/Settings/ApiKeys"; export default function Settings({organisationId, sendDatabaseData}) { const [activeTab, setActiveTab] = useState('model'); @@ -44,11 +45,18 @@ export default function Settings({organisationId, sendDatabaseData}) { alt="database-icon"/> Database
+
+ +
{activeTab === 'model' && } {activeTab === 'database' && } + {activeTab === 'apikeys' && }
diff --git a/gui/pages/_app.css b/gui/pages/_app.css index 8c6a78a34..7be317c86 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -997,13 +997,16 @@ p { .r_0{right: 0} .w_120p{width: 120px} +.w_4{width: 4%} .w_6{width: 6%} .w_10{width: 10%} .w_12{width: 12%} +.w_18{width: 18%} .w_20{width: 20%} .w_22{width: 22%} .w_35{width: 35%} .w_56{width: 56%} +.w_60{width: 60%} .w_100{width: 100%} .w_inherit{width: inherit} .w_fit_content{width:fit-content} @@ -1052,6 +1055,7 @@ p { .gap_16{gap:16px;} .gap_20{gap:20px;} +.border_top_none{border-top: none;} .border_radius_8{border-radius: 8px;} .border_radius_25{border-radius: 25px;} @@ -1072,12 +1076,15 @@ p { .padding_0_8{padding: 0px 8px;} .padding_0_15{padding: 0px 15px;} +.flex_1{flex: 1} .flex_wrap{flex-wrap: wrap;} .mix_blend_mode{mix-blend-mode: exclusion;} .ff_sourceCode{font-family: 'Source Code Pro'} +.rotate_90{transform: rotate(90deg)} + /*------------------------------- My ROWS AND COLUMNS -------------------------------*/ .my_rows { display: flex; @@ -1674,6 +1681,24 @@ tr{ background: #474255; } +.loading_container{ + display: flex; + justify-content: center; + align-items: center; + height: 50vh +} + +.loading_text{ + font-size: 16px; + font-family: 'Source Code Pro'; +} + +.table_container{ + background: #272335; + border-radius: 8px; + margin-top:15px +} + .top_bar_profile_dropdown{ display: flex; flex-direction: row; diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index ed9b0a1ef..625c6c2af 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -303,3 +303,20 @@ export const fetchKnowledgeTemplateOverview = (knowledgeName) => { export const installKnowledgeTemplate = (knowledgeName, indexId) => { return api.get(`/knowledges/install/${knowledgeName}/index/${indexId}`); }; + +export const createApiKey = (apiName) => { + return api.post(`/api-keys`, apiName); +}; + +export const getApiKeys = () => { + return api.get(`/api-keys`); +}; + +export const editApiKey = (apiDetails) => { + return api.put(`/api-keys`, apiDetails); +}; + +export const deleteApiKey = (apiId) => { + return api.delete(`/api-keys/${apiId}`); +}; + diff --git a/gui/public/images/copy_icon.svg b/gui/public/images/copy_icon.svg new file mode 100644 index 000000000..46644e0dc --- /dev/null +++ b/gui/public/images/copy_icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/gui/public/images/key_white.svg b/gui/public/images/key_white.svg new file mode 100644 index 000000000..9b35b92e0 --- /dev/null +++ b/gui/public/images/key_white.svg @@ -0,0 +1,3 @@ + + + diff --git a/main.py b/main.py index f6227a179..4c698eb9a 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,9 @@ from superagi.controllers.vector_dbs import router as vector_dbs_router from superagi.controllers.vector_db_indices import router as vector_db_indices_router from superagi.controllers.marketplace_stats import router as marketplace_stats_router +from superagi.controllers.api_key import router as api_key_router +from superagi.controllers.api.agent import router as api_agent_router +from superagi.controllers.webhook import router as web_hook_router from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits from superagi.lib.logger import logger from superagi.llms.google_palm import GooglePalm @@ -50,7 +53,6 @@ from superagi.models.workflows.agent_workflow import AgentWorkflow from superagi.models.workflows.iteration_workflow import IterationWorkflow from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep - app = FastAPI() database_url = get_config('POSTGRES_URL') @@ -113,7 +115,9 @@ app.include_router(vector_dbs_router, prefix="/vector_dbs") app.include_router(vector_db_indices_router, prefix="/vector_db_indices") app.include_router(marketplace_stats_router, prefix="/marketplace") - +app.include_router(api_key_router, prefix="/api-keys") +app.include_router(api_agent_router,prefix="/v1/agent") +app.include_router(web_hook_router,prefix="/webhook") # in production you can use Settings management # from pydantic to get secret key from .env @@ -370,3 +374,4 @@ def github_client_id(): # # __________________TO RUN____________________________ # # uvicorn main:app --host 0.0.0.0 --port 8001 --reload + diff --git a/migrations/versions/446884dcae58_add_api_key_and_web_hook.py b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py new file mode 100644 index 000000000..c4b353756 --- /dev/null +++ b/migrations/versions/446884dcae58_add_api_key_and_web_hook.py @@ -0,0 +1,65 @@ +"""add api_key and web_hook + +Revision ID: 446884dcae58 +Revises: 71e3980d55f5 +Create Date: 2023-07-29 10:55:21.714245 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '446884dcae58' +down_revision = '2fbd6472112c' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('api_keys', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('org_id', sa.Integer(), nullable=True), + sa.Column('name', sa.String(), nullable=True), + sa.Column('key', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('is_expired',sa.Boolean(),nullable=True,default=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('webhooks', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('org_id', sa.Integer(), nullable=True), + sa.Column('url', sa.String(), nullable=True), + sa.Column('headers', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('is_deleted',sa.Boolean(),nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('webhook_events', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('agent_id', sa.Integer(), nullable=True), + sa.Column('run_id', sa.Integer(), nullable=True), + sa.Column('event', sa.String(), nullable=True), + sa.Column('status', sa.String(), nullable=True), + sa.Column('errors', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + #add index ********************* + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_table('webhooks') + op.drop_table('api_keys') + op.drop_table('webhook_events') + + # ### end Alembic commands ### diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py new file mode 100644 index 000000000..a71e60a8f --- /dev/null +++ b/superagi/controllers/api/agent.py @@ -0,0 +1,326 @@ +from fastapi import APIRouter +from fastapi import HTTPException, Depends ,Security + +from fastapi_sqlalchemy import db +from pydantic import BaseModel + +from superagi.worker import execute_agent +from superagi.helper.auth import validate_api_key,get_organisation_from_api_key +from superagi.models.agent import Agent +from superagi.models.agent_execution_config import AgentExecutionConfiguration +from superagi.models.agent_config import AgentConfiguration +from superagi.models.agent_schedule import AgentSchedule +from superagi.models.project import Project +from superagi.models.workflows.agent_workflow import AgentWorkflow +from superagi.models.agent_execution import AgentExecution +from superagi.models.organisation import Organisation +from superagi.models.resource import Resource +from superagi.controllers.types.agent_with_config import AgentConfigExtInput,AgentConfigUpdateExtInput +from superagi.models.workflows.iteration_workflow import IterationWorkflow +from superagi.helper.s3_helper import S3Helper +from datetime import datetime +from typing import Optional,List +from superagi.models.toolkit import Toolkit +from superagi.apm.event_handler import EventHandler +from superagi.config.config import get_config +router = APIRouter() + +class AgentExecutionIn(BaseModel): + name: Optional[str] + goal: Optional[List[str]] + instruction: Optional[List[str]] + + class Config: + orm_mode = True + +class RunFilterConfigIn(BaseModel): + run_ids:Optional[List[int]] + run_status_filter:Optional[str] + + class Config: + orm_mode = True + +class ExecutionStateChangeConfigIn(BaseModel): + run_ids:Optional[List[int]] + + class Config: + orm_mode = True + +class RunIDConfig(BaseModel): + run_ids:List[int] + + class Config: + orm_mode = True + +@router.post("", status_code=200) +def create_agent_with_config(agent_with_config: AgentConfigExtInput, + api_key: str = Security(validate_api_key), organisation:Organisation = Depends(get_organisation_from_api_key)): + project=Project.find_by_org_id(db.session, organisation.id) + try: + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + except Exception as e: + raise HTTPException(status_code=404, detail=str(e)) + + agent_with_config.tools=tools_arr + agent_with_config.project_id=project.id + agent_with_config.exit="No exit criterion" + agent_with_config.permission_type="God Mode" + agent_with_config.LTM_DB=None + db_agent = Agent.create_agent_with_config(db, agent_with_config) + + if agent_with_config.schedule is not None: + agent_schedule = AgentSchedule.save_schedule_from_config(db.session, db_agent, agent_with_config.schedule) + if agent_schedule is None: + raise HTTPException(status_code=500, detail="Failed to schedule agent") + EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name, + 'model': agent_with_config.model}, db_agent.id, + organisation.id if organisation else 0) + db.session.commit() + return { + "agent_id": db_agent.id + } + + start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id) + iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session, + start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 + # Creating an execution with RUNNING status + execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id, + name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id) + agent_execution_configs = { + "goal": agent_with_config.goal, + "instruction": agent_with_config.instruction + } + db.session.add(execution) + db.session.commit() + db.session.flush() + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution, + agent_execution_configs=agent_execution_configs) + + organisation = db_agent.get_agent_organisation(db.session) + EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name, + 'model': agent_with_config.model}, db_agent.id, + organisation.id if organisation else 0) + # execute_agent.delay(execution.id, datetime.now()) + db.session.commit() + return { + "agent_id": db_agent.id + } + +@router.post("/{agent_id}/run",status_code=200) +def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent=Agent.get_agent_from_id(db.session,agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id) + if db_schedule is not None: + raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot run") + start_step_id = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id) + db_agent_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "CREATED") + + if db_agent_execution is None: + db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), + agent_id=agent_id, name=agent_execution.name, num_of_calls=0, + num_of_tokens=0, + current_step_id=start_step_id) + db.session.add(db_agent_execution) + else: + db_agent_execution.status = "RUNNING" + + db.session.commit() + db.session.flush() + + agent_execution_configs = {} + if agent_execution.goal is not None: + agent_execution_configs = { + "goal": agent_execution.goal, + } + + if agent_execution.instruction is not None: + agent_execution_configs["instructions"] = agent_execution.instruction, + + if agent_execution_configs != {}: + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution, + agent_execution_configs=agent_execution_configs) + EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, + agent_id, organisation.id if organisation else 0) + + if db_agent_execution.status == "RUNNING": + execute_agent.delay(db_agent_execution.id, datetime.now()) + return { + "run_id":db_agent_execution.id + } + +@router.put("/{agent_id}",status_code=200) +def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api_key: str = Security(validate_api_key), + organisation:Organisation = Depends(get_organisation_from_api_key)): + + db_agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not db_agent: + raise HTTPException(status_code=404, detail="agent not found") + + project=Project.find_by_id(db.session, db_agent.project_id) + if project is None: + raise HTTPException(status_code=404, detail="Project not found") + + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + db_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "RUNNING") + if db_execution is not None: + raise HTTPException(status_code=409, detail="Agent is already running,please pause and then update") + + db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id) + if db_schedule is not None: + raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update") + + try: + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + except Exception as e: + raise HTTPException(status_code=404,detail=str(e)) + + if agent_with_config.schedule is not None: + raise HTTPException(status_code=400,detail="Cannot schedule an existing agent") + agent_with_config.tools=tools_arr + agent_with_config.project_id=project.id + agent_with_config.exit="No exit criterion" + agent_with_config.permission_type="God Mode" + agent_with_config.LTM_DB=None + + for key,value in agent_with_config.dict().items(): + if hasattr(db_agent,key) and value is not None: + setattr(db_agent,key,value) + db.session.commit() + db.session.flush() + + start_step = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id) + iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session, + start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 + execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id, + name="New Run", current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id) + agent_execution_configs = { + "goal": agent_with_config.goal, + "instruction": agent_with_config.instruction, + "tools":agent_with_config.tools, + "constraints": agent_with_config.constraints, + "iteration_interval": agent_with_config.iteration_interval, + "model": agent_with_config.model, + "max_iterations": agent_with_config.max_iterations, + "agent_workflow": agent_with_config.agent_workflow, + } + agent_configurations = [ + AgentConfiguration(agent_id=db_agent.id, key=key, value=str(value)) + for key, value in agent_execution_configs.items() + ] + db.session.add_all(agent_configurations) + db.session.add(execution) + db.session.commit() + db.session.flush() + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution, + agent_execution_configs=agent_execution_configs) + db.session.commit() + + return { + "agent_id":db_agent.id + } + + +@router.post("/{agent_id}/run-status") +def get_agent_runs(agent_id:int,filter_config:RunFilterConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + db_execution_arr=[] + if filter_config.run_status_filter is not None: + filter_config.run_status_filter=filter_config.run_status_filter.upper() + + db_execution_arr=AgentExecution.get_all_executions_by_filter_config(db.session, agent.id, filter_config) + + response_arr=[] + for ind_execution in db_execution_arr: + response_arr.append({"run_id":ind_execution.id, "status":ind_execution.status}) + + return response_arr + + +@router.post("/{agent_id}/pause",status_code=200) +def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateChangeConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + #Checking if the run_ids whose output files are requested belong to the organisation + if execution_state_change_input.run_ids is not None: + try: + AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail="One or more run_ids not found") + + db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "RUNNING") + for ind_execution in db_execution_arr: + ind_execution.status="PAUSED" + db.session.commit() + db.session.flush() + return { + "result":"success" + } + +@router.post("/{agent_id}/resume",status_code=200) +def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateChangeConfigIn,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + agent= Agent.get_active_agent_by_id(db.session, agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + + project=Project.find_by_id(db.session, agent.project_id) + if project.organisation_id!=organisation.id: + raise HTTPException(status_code=404, detail="Agent not found") + + if execution_state_change_input.run_ids is not None: + try: + AgentExecution.validate_run_ids(db.session,execution_state_change_input.run_ids,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail="One or more run_ids not found") + + db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "PAUSED") + for ind_execution in db_execution_arr: + ind_execution.status="RUNNING" + + db.session.commit() + db.session.flush() + return { + "result":"success" + } + +@router.post("/resources/output",status_code=201) +def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): + if get_config('STORAGE_TYPE') != "S3": + raise HTTPException(status_code=400,detail="This endpoint only works when S3 is configured") + run_ids_arr=run_id_config.run_ids + if len(run_ids_arr)==0: + raise HTTPException(status_code=404, + detail=f"No execution_id found") + #Checking if the run_ids whose output files are requested belong to the organisation + try: + AgentExecution.validate_run_ids(db.session,run_ids_arr,organisation.id) + except Exception as e: + raise HTTPException(status_code=404, detail="One or more run_ids not found") + + db_resources_arr=Resource.find_by_run_ids(db.session, run_ids_arr) + + try: + response_obj=S3Helper().get_download_url_of_resources(db_resources_arr) + except: + raise HTTPException(status_code=401,detail="Invalid S3 credentials") + return response_obj + diff --git a/superagi/controllers/api_key.py b/superagi/controllers/api_key.py new file mode 100644 index 000000000..57e5c739b --- /dev/null +++ b/superagi/controllers/api_key.py @@ -0,0 +1,55 @@ +import json +import uuid +from fastapi import APIRouter, Body +from fastapi import HTTPException, Depends +from fastapi_jwt_auth import AuthJWT +from fastapi_sqlalchemy import db +from pydantic import BaseModel +from superagi.helper.auth import get_user_organisation +from superagi.helper.auth import check_auth +from superagi.models.api_key import ApiKey +from typing import Optional, Annotated +router = APIRouter() + +class ApiKeyIn(BaseModel): + id:int + name: str + class Config: + orm_mode = True + +class ApiKeyDeleteIn(BaseModel): + id:int + class Config: + orm_mode = True + +@router.post("") +def create_api_key(name: Annotated[str,Body(embed=True)], Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): + api_key=str(uuid.uuid4()) + obj=ApiKey(key=api_key,name=name,org_id=organisation.id) + db.session.add(obj) + db.session.commit() + db.session.flush() + return {"api_key": api_key} + +@router.get("") +def get_all(Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): + api_keys=ApiKey.get_by_org_id(db.session, organisation.id) + return api_keys + +@router.delete("/{api_key_id}") +def delete_api_key(api_key_id:int, Authorize: AuthJWT = Depends(check_auth)): + api_key=ApiKey.get_by_id(db.session, api_key_id) + if api_key is None: + raise HTTPException(status_code=404, detail="API key not found") + ApiKey.delete_by_id(db.session, api_key_id) + return {"success": True} + +@router.put("") +def edit_api_key(api_key_in:ApiKeyIn,Authorize: AuthJWT = Depends(check_auth)): + api_key=ApiKey.get_by_id(db.session, api_key_in.id) + if api_key is None: + raise HTTPException(status_code=404, detail="API key not found") + ApiKey.update_api_key(db.session, api_key_in.id, api_key_in.name) + return {"success": True} + + diff --git a/superagi/controllers/types/agent_with_config.py b/superagi/controllers/types/agent_with_config.py index f3995b5df..5ce81d211 100644 --- a/superagi/controllers/types/agent_with_config.py +++ b/superagi/controllers/types/agent_with_config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel from typing import List, Optional - +from superagi.controllers.types.agent_schedule import AgentScheduleInput class AgentConfigInput(BaseModel): name: str @@ -20,3 +20,45 @@ class AgentConfigInput(BaseModel): max_iterations: int user_timezone: Optional[str] knowledge: Optional[int] + + + +class AgentConfigExtInput(BaseModel): + name: str + description: str + project_id: Optional[int] + goal: List[str] + instruction: List[str] + agent_workflow: str + constraints: List[str] + tools: List[dict] + LTM_DB:Optional[str] + exit: Optional[str] + permission_type: Optional[str] + iteration_interval: int + model: str + schedule: Optional[AgentScheduleInput] + max_iterations: int + user_timezone: Optional[str] + knowledge: Optional[int] + +class AgentConfigUpdateExtInput(BaseModel): + name: Optional[str] + description: Optional[str] + project_id: Optional[int] + goal: Optional[List[str]] + instruction: Optional[List[str]] + agent_workflow: Optional[str] + constraints: Optional[List[str]] + tools: Optional[List[dict]] + LTM_DB:Optional[str] + exit: Optional[str] + permission_type: Optional[str] + iteration_interval: Optional[int] + model: Optional[str] + schedule: Optional[AgentScheduleInput] + max_iterations: Optional[int] + user_timezone: Optional[str] + knowledge: Optional[int] + + diff --git a/superagi/controllers/webhook.py b/superagi/controllers/webhook.py new file mode 100644 index 000000000..0a55bd216 --- /dev/null +++ b/superagi/controllers/webhook.py @@ -0,0 +1,60 @@ +from datetime import datetime + +from fastapi import APIRouter +from fastapi import Depends +from fastapi_jwt_auth import AuthJWT +from fastapi_sqlalchemy import db +from pydantic import BaseModel + +# from superagi.types.db import AgentOut, AgentIn +from superagi.helper.auth import check_auth, get_user_organisation +from superagi.models.webhooks import Webhooks + +router = APIRouter() + + +class WebHookIn(BaseModel): + name: str + url: str + headers: dict + + class Config: + orm_mode = True + + +class WebHookOut(BaseModel): + id: int + org_id: int + name: str + url: str + headers: dict + is_deleted: bool + created_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + + +# CRUD Operations +@router.post("/add", response_model=WebHookOut, status_code=201) +def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth), + organisation=Depends(get_user_organisation)): + """ + Creates a new webhook + + Args: + + Returns: + Agent: An object of Agent representing the created Agent. + + Raises: + HTTPException (Status Code=404): If the associated project is not found. + """ + db_webhook = Webhooks(name=webhook.name, url=webhook.url, headers=webhook.headers, org_id=organisation.id, + is_deleted=False) + db.session.add(db_webhook) + db.session.commit() + db.session.flush() + + return db_webhook diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py index cc02d5643..f916a3ae5 100644 --- a/superagi/helper/auth.py +++ b/superagi/helper/auth.py @@ -1,10 +1,14 @@ -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Header, Security, status +from fastapi.security import APIKeyHeader from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db - +from fastapi.security.api_key import APIKeyHeader from superagi.config.config import get_config from superagi.models.organisation import Organisation from superagi.models.user import User +from superagi.models.api_key import ApiKey +from typing import Optional +from sqlalchemy import or_ def check_auth(Authorize: AuthJWT = Depends()): @@ -39,6 +43,7 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)): organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first() return organisation + def get_current_user(Authorize: AuthJWT = Depends(check_auth)): env = get_config("ENV", "DEV") @@ -50,4 +55,32 @@ def get_current_user(Authorize: AuthJWT = Depends(check_auth)): # Query the User table to find the user by their email user = db.session.query(User).filter(User.email == email).first() - return user \ No newline at end of file + return user + + +api_key_header = APIKeyHeader(name="X-API-Key") + + +def validate_api_key(api_key: str = Security(api_key_header)) -> str: + query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key, + or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() + if query_result is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API Key", + ) + + return query_result.key + + +def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Organisation: + query_result = db.session.query(ApiKey).filter(ApiKey.key == api_key, + or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() + if query_result is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API Key", + ) + + organisation = db.session.query(Organisation).filter(Organisation.id == query_result.org_id).first() + return organisation \ No newline at end of file diff --git a/superagi/helper/s3_helper.py b/superagi/helper/s3_helper.py index 2b2d0ba1f..2669b77ed 100644 --- a/superagi/helper/s3_helper.py +++ b/superagi/helper/s3_helper.py @@ -5,10 +5,9 @@ from superagi.config.config import get_config from superagi.lib.logger import logger +from urllib.parse import unquote import json - - class S3Helper: def __init__(self, bucket_name = get_config("BUCKET_NAME")): """ @@ -113,4 +112,27 @@ def upload_file_content(self, content, file_path): try: self.s3.put_object(Bucket=self.bucket_name, Key=file_path, Body=content) except: - raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") \ No newline at end of file + raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") + + def get_download_url_of_resources(self,db_resources_arr): + s3 = boto3.client( + 's3', + aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"), + ) + response_obj={} + for db_resource in db_resources_arr: + response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path) + content = response["Body"].read() + bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME") + file_name=db_resource.path.split('/')[-1] + file_name=''.join(char for char in file_name if char != "`") + object_key=f"public_resources/run_id{db_resource.agent_execution_id}/{file_name}" + s3.put_object(Bucket=bucket_name, Key=object_key, Body=content) + file_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}" + resource_execution_id=db_resource.agent_execution_id + if resource_execution_id in response_obj: + response_obj[resource_execution_id].append(file_url) + else: + response_obj[resource_execution_id]=[file_url] + return response_obj \ No newline at end of file diff --git a/superagi/helper/webhook_manager.py b/superagi/helper/webhook_manager.py new file mode 100644 index 000000000..cf5e988d2 --- /dev/null +++ b/superagi/helper/webhook_manager.py @@ -0,0 +1,37 @@ +from superagi.models.agent import Agent +from superagi.models.agent_execution import AgentExecution +from superagi.models.webhooks import Webhooks +from superagi.models.webhook_events import WebhookEvents +import requests +import json +from superagi.lib.logger import logger +class WebHookManager: + def __init__(self,session): + self.session=session + + def agent_status_change_callback(self, agent_execution_id, curr_status, old_status): + if curr_status=="CREATED" or agent_execution_id is None: + return + agent_id=AgentExecution.get_agent_execution_from_id(self.session,agent_execution_id).agent_id + agent=Agent.get_agent_from_id(self.session,agent_id) + org=agent.get_agent_organisation(self.session) + org_webhooks=self.session.query(Webhooks).filter(Webhooks.org_id == org.id).all() + + for webhook_obj in org_webhooks: + webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"} + error=None + request=None + status='sent' + try: + request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers) + except Exception as e: + logger.error(f"Exception occured in webhooks {e}") + error=str(e) + if request is not None and request.status_code not in [200,201] and error is None: + error=request.text + if error is not None: + status='Error' + webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error) + self.session.add(webhook_event) + self.session.commit() + diff --git a/superagi/models/agent.py b/superagi/models/agent.py index 38a8796fc..49bd93f9d 100644 --- a/superagi/models/agent.py +++ b/superagi/models/agent.py @@ -4,8 +4,10 @@ import json from sqlalchemy import Column, Integer, String, Boolean +from sqlalchemy import or_ from superagi.lib.logger import logger +from superagi.models.agent_config import AgentConfiguration from superagi.models.agent_template import AgentTemplate from superagi.models.agent_template_config import AgentTemplateConfig # from superagi.models import AgentConfiguration @@ -13,7 +15,7 @@ from superagi.models.organisation import Organisation from superagi.models.project import Project from superagi.models.workflows.agent_workflow import AgentWorkflow -from superagi.models.agent_config import AgentConfiguration + class Agent(DBBaseModel): """ @@ -35,8 +37,8 @@ class Agent(DBBaseModel): project_id = Column(Integer) description = Column(String) agent_workflow_id = Column(Integer) - is_deleted = Column(Boolean, default = False) - + is_deleted = Column(Boolean, default=False) + def __repr__(self): """ Returns a string representation of the Agent object. @@ -47,8 +49,8 @@ def __repr__(self): """ return f"Agent(id={self.id}, name='{self.name}', project_id={self.project_id}, " \ f"description='{self.description}', agent_workflow_id={self.agent_workflow_id}," \ - f"is_deleted='{self.is_deleted}')" - + f"is_deleted='{self.is_deleted}')" + @classmethod def fetch_configuration(cls, session, agent_id: int): """ @@ -105,7 +107,8 @@ def eval_agent_config(cls, key, value): """ - if key in ["name", "description", "exit", "model", "permission_type", "LTM_DB", "resource_summary", "knowledge"]: + if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB", + "resource_summary", "knowledge"]: return value elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]: return int(value) @@ -150,7 +153,6 @@ def create_agent_with_config(cls, db, agent_with_config): # AgentWorkflow.name == "Fixed Task Queue").first() # db_agent.agent_workflow_id = agent_workflow.id - db.session.commit() # Create Agent Configuration @@ -291,3 +293,9 @@ def find_org_by_agent_id(cls, session, agent_id: int): agent = session.query(Agent).filter_by(id=agent_id).first() project = session.query(Project).filter(Project.id == agent.project_id).first() return session.query(Organisation).filter(Organisation.id == project.organisation_id).first() + + @classmethod + def get_active_agent_by_id(cls, session, agent_id: int): + db_agent = session.query(Agent).filter(Agent.id == agent_id, + or_(Agent.is_deleted == False, Agent.is_deleted is None)).first() + return db_agent diff --git a/superagi/models/agent_execution.py b/superagi/models/agent_execution.py index f95afb85b..5cba5f509 100644 --- a/superagi/models/agent_execution.py +++ b/superagi/models/agent_execution.py @@ -164,4 +164,49 @@ def assign_next_step_id(cls, session, agent_execution_id: int, next_step_id: int if next_step.action_type == "ITERATION_WORKFLOW": trigger_step = IterationWorkflow.fetch_trigger_step_id(session, next_step.action_reference_id) agent_execution.iteration_workflow_step_id = trigger_step.id - session.commit() \ No newline at end of file + session.commit() + + @classmethod + def get_execution_by_agent_id_and_status(cls, session, agent_id: int, status_filter: str): + db_agent_execution = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == status_filter).first() + return db_agent_execution + + + @classmethod + def get_all_executions_by_status_and_agent_id(cls, session, agent_id, execution_state_change_input, current_status: str): + db_execution_arr = [] + if execution_state_change_input.run_ids is not None: + db_execution_arr = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == current_status,AgentExecution.id.in_(execution_state_change_input.run_ids)).all() + else: + db_execution_arr = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id, AgentExecution.status == current_status).all() + return db_execution_arr + + @classmethod + def get_all_executions_by_filter_config(cls, session, agent_id: int, filter_config): + db_execution_query = session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id) + if filter_config.run_ids is not None: + db_execution_query = db_execution_query.filter(AgentExecution.id.in_(filter_config.run_ids)) + + if filter_config.run_status_filter is not None and filter_config.run_status_filter in ["CREATED", "RUNNING", + "PAUSED", "COMPLETED", + "TERMINATED"]: + db_execution_query = db_execution_query.filter(AgentExecution.status == filter_config.run_status_filter) + + db_execution_arr = db_execution_query.all() + return db_execution_arr + + @classmethod + def validate_run_ids(cls, session, run_ids: list, organisation_id: int): + from superagi.models.agent import Agent + from superagi.models.project import Project + + run_ids=list(set(run_ids)) + agent_ids=session.query(AgentExecution.agent_id).filter(AgentExecution.id.in_(run_ids)).distinct().all() + agent_ids = [id for (id,) in agent_ids] + project_ids=session.query(Agent.project_id).filter(Agent.id.in_(agent_ids)).distinct().all() + project_ids = [id for (id,) in project_ids] + org_ids=session.query(Project.organisation_id).filter(Project.id.in_(project_ids)).distinct().all() + org_ids = [id for (id,) in org_ids] + + if len(org_ids) > 1 or org_ids[0] != organisation_id: + raise Exception(f"one or more run IDs not found") diff --git a/superagi/models/agent_schedule.py b/superagi/models/agent_schedule.py index 6415c375c..32e8dcd84 100644 --- a/superagi/models/agent_schedule.py +++ b/superagi/models/agent_schedule.py @@ -45,4 +45,27 @@ def __repr__(self): f"expiry_date={self.expiry_date}, " \ f"expiry_runs={self.expiry_runs}), " \ f"current_runs={self.expiry_runs}), " \ - f"status={self.status}), " \ No newline at end of file + f"status={self.status}), " + + @classmethod + def save_schedule_from_config(cls, session, db_agent, schedule_config: AgentScheduleInput): + agent_schedule = AgentSchedule( + agent_id=db_agent.id, + start_time=schedule_config.start_time, + next_scheduled_time=schedule_config.start_time, + recurrence_interval=schedule_config.recurrence_interval, + expiry_date=schedule_config.expiry_date, + expiry_runs=schedule_config.expiry_runs, + current_runs=0, + status="SCHEDULED" + ) + + agent_schedule.agent_id = db_agent.id + session.add(agent_schedule) + session.commit() + return agent_schedule + + @classmethod + def find_by_agent_id(cls, session, agent_id: int): + db_schedule=session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id).first() + return db_schedule diff --git a/superagi/models/api_key.py b/superagi/models/api_key.py new file mode 100644 index 000000000..1cc3e310a --- /dev/null +++ b/superagi/models/api_key.py @@ -0,0 +1,46 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey +from sqlalchemy.orm import relationship +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution +from sqlalchemy import func, or_ + +class ApiKey(DBBaseModel): + """ + + Attributes: + + + Methods: + """ + __tablename__ = 'api_keys' + + id = Column(Integer, primary_key=True) + org_id = Column(Integer) + name = Column(String) + key = Column(String) + is_expired= Column(Boolean) + + @classmethod + def get_by_org_id(cls, session, org_id: int): + db_api_keys=session.query(ApiKey).filter(ApiKey.org_id==org_id,or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).all() + return db_api_keys + + @classmethod + def get_by_id(cls, session, id: int): + db_api_key=session.query(ApiKey).filter(ApiKey.id==id,or_(ApiKey.is_expired == False, ApiKey.is_expired == None)).first() + return db_api_key + + @classmethod + def delete_by_id(cls, session,id: int): + db_api_key = session.query(ApiKey).filter(ApiKey.id == id).first() + db_api_key.is_expired = True + session.commit() + session.flush() + + @classmethod + def update_api_key(cls, session, id: int, name: str): + db_api_key = session.query(ApiKey).filter(ApiKey.id == id).first() + db_api_key.name = name + session.commit() + session.flush() + diff --git a/superagi/models/project.py b/superagi/models/project.py index 1de1eb624..c798e1d1a 100644 --- a/superagi/models/project.py +++ b/superagi/models/project.py @@ -55,3 +55,13 @@ def find_or_create_default_project(cls, session, organisation_id): else: default_project = project return default_project + + @classmethod + def find_by_org_id(cls, session, org_id: int): + project = session.query(Project).filter(Project.organisation_id == org_id).first() + return project + + @classmethod + def find_by_id(cls, session, project_id: int): + project = session.query(Project).filter(Project.id == project_id).first() + return project \ No newline at end of file diff --git a/superagi/models/resource.py b/superagi/models/resource.py index 926123e47..78713b690 100644 --- a/superagi/models/resource.py +++ b/superagi/models/resource.py @@ -58,6 +58,11 @@ def validate_resource_type(storage_type): if storage_type not in valid_types: raise InvalidResourceType("Invalid resource type") - + + @classmethod + def find_by_run_ids(cls, session, run_ids: list): + db_resources_arr=session.query(Resource).filter(Resource.agent_execution_id.in_(run_ids)).all() + return db_resources_arr + class InvalidResourceType(Exception): """Custom exception for invalid resource type""" diff --git a/superagi/models/toolkit.py b/superagi/models/toolkit.py index 2c89a38ab..5a9c0a0e9 100644 --- a/superagi/models/toolkit.py +++ b/superagi/models/toolkit.py @@ -138,3 +138,26 @@ def fetch_tool_ids_from_toolkit(cls, session, toolkit_ids): if tool is not None: agent_toolkit_tools.append(tool.id) return agent_toolkit_tools + + @classmethod + def get_tool_and_toolkit_arr(cls, session, agent_config_tools_arr: list): + from superagi.models.tool import Tool + toolkits_arr= set() + tools_arr= set() + for tool_obj in agent_config_tools_arr: + toolkit=session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip()).first() + if toolkit is None: + raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.") + toolkits_arr.add(toolkit.id) + if tool_obj.get("tools"): + for tool_name_str in tool_obj["tools"]: + tool_db_obj=session.query(Tool).filter(Tool.name == tool_name_str.strip()).first() + if tool_db_obj is None: + raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.") + + tools_arr.add(tool_db_obj.id) + else: + tools=Tool.get_toolkit_tools(session, toolkit.id) + for tool_db_obj in tools: + tools_arr.add(tool_db_obj.id) + return list(tools_arr) diff --git a/superagi/models/webhook_events.py b/superagi/models/webhook_events.py new file mode 100644 index 000000000..af90ea492 --- /dev/null +++ b/superagi/models/webhook_events.py @@ -0,0 +1,25 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey +from sqlalchemy.orm import relationship +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution + + +class WebhookEvents(DBBaseModel): + """ + + Attributes: + + + Methods: + """ + __tablename__ = 'webhook_events' + + id = Column(Integer, primary_key=True) + agent_id=Column(Integer) + run_id = Column(Integer) + event = Column(String) + status = Column(String) + errors= Column(Text) + + + diff --git a/superagi/models/webhooks.py b/superagi/models/webhooks.py new file mode 100644 index 000000000..14d683472 --- /dev/null +++ b/superagi/models/webhooks.py @@ -0,0 +1,22 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey,JSON +from sqlalchemy.orm import relationship +from sqlalchemy.dialects.postgresql import JSONB +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution + +class Webhooks(DBBaseModel): + """ + + Attributes: + + + Methods: + """ + __tablename__ = 'webhooks' + + id = Column(Integer, primary_key=True) + name=Column(String) + org_id = Column(Integer) + url = Column(String) + headers=Column(JSON) + is_deleted=Column(Boolean) diff --git a/superagi/worker.py b/superagi/worker.py index 9e41c2fd8..b6e5ea231 100644 --- a/superagi/worker.py +++ b/superagi/worker.py @@ -15,6 +15,10 @@ from superagi.models.db import connect_db from superagi.types.model_source_types import ModelSourceType +from sqlalchemy import event +from superagi.models.agent_execution import AgentExecution +from superagi.helper.webhook_manager import WebHookManager + redis_url = get_config('REDIS_URL', 'super__redis:6379') app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"]) @@ -32,9 +36,16 @@ } app.conf.beat_schedule = beat_schedule +@event.listens_for(AgentExecution.status, "set") +def agent_status_change(target, val,old_val,initiator): + if get_config("IN_TESTING",False): + webhook_callback.delay(target.id,val,old_val) + + @app.task(name="initialize-schedule-agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5) def initialize_schedule_agent_task(): """Executing agent scheduling in the background.""" + schedule_helper = AgentScheduleHelper() schedule_helper.update_next_scheduled_time() schedule_helper.run_scheduled_agents() @@ -49,7 +60,7 @@ def execute_agent(agent_execution_id: int, time): AgentExecutor().execute_next_step(agent_execution_id=agent_execution_id) -@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5, serializer='pickle') +@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle') def summarize_resource(agent_id: int, resource_id: int): """Summarize a resource in background.""" from superagi.resource_manager.resource_summary import ResourceSummarizer @@ -77,3 +88,11 @@ def summarize_resource(agent_id: int, resource_id: int): resource_summarizer.add_to_vector_store_and_create_summary(resource_id=resource_id, documents=documents) session.close() + +@app.task(name="webhook_callback", autoretry_for=(Exception,), retry_backoff=2, max_retries=5,serializer='pickle') +def webhook_callback(agent_execution_id,val,old_val): + engine = connect_db() + Session = sessionmaker(bind=engine) + with Session() as session: + WebHookManager(session).agent_status_change_callback(agent_execution_id, val, old_val) + diff --git a/tests/unit_tests/controllers/api/__init__.py b/tests/unit_tests/controllers/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/controllers/api/test_agent.py b/tests/unit_tests/controllers/api/test_agent.py new file mode 100644 index 000000000..99fda13b2 --- /dev/null +++ b/tests/unit_tests/controllers/api/test_agent.py @@ -0,0 +1,220 @@ +import pytest +from fastapi.testclient import TestClient +from fastapi import HTTPException + +import superagi.config.config +from unittest.mock import MagicMock, patch,Mock +from main import app +from unittest.mock import patch,create_autospec +from sqlalchemy.orm import Session +from superagi.controllers.api.agent import ExecutionStateChangeConfigIn,AgentConfigUpdateExtInput +from superagi.models.agent import Agent +from superagi.models.project import Project + +client = TestClient(app) + +@pytest.fixture +def mock_api_key_get(): + mock_api_key = "your_mock_api_key" + return mock_api_key +@pytest.fixture +def mock_execution_state_change_input(): + return { + + } +@pytest.fixture +def mock_run_id_config(): + return { + "run_ids":[1,2] + } + +@pytest.fixture +def mock_agent_execution(): + return { + + } +@pytest.fixture +def mock_run_id_config_empty(): + return { + "run_ids":[] + } + +@pytest.fixture +def mock_run_id_config_invalid(): + return { + "run_ids":[12310] + } +@pytest.fixture +def mock_agent_config_update_ext_input(): + return AgentConfigUpdateExtInput( + tools=[{"name":"Image Generation Toolkit"}], + schedule=None, + goal=["Test Goal"], + instruction=["Test Instruction"], + constraints=["Test Constraints"], + iteration_interval=10, + model="Test Model", + max_iterations=100, + agent_type="Test Agent Type" + ) + +@pytest.fixture +def mock_update_agent_config(): + return { + "name": "agent_3_UPDATED", + "description": "AI assistant to solve complex problems", + "goal": ["create a photo of a cat"], + "agent_type": "Dynamic Task Workflow", + "constraints": [ + "~4000 word limit for short term memory.", + "Your long term memory is short, so immediately save important information to files.", + "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.", + "No user assistance", + "Exclusively use the commands listed in double quotes e.g. \"command name\"" + ], + "instruction": ["Be accurate"], + "tools":[ + { + "name":"Image Generation Toolkit" + } + ], + "iteration_interval": 500, + "model": "gpt-4", + "max_iterations": 100 + } +# Define test cases + +def test_update_agent_not_found(mock_update_agent_config,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.put( + "/v1/agent/1", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_update_agent_config + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + + +def test_get_run_resources_no_run_ids(mock_run_id_config_empty,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock, \ + patch('superagi.controllers.api.agent.get_config', return_value="S3") as mock_get_config: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "v1/agent/resources/output", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_run_id_config_empty + ) + assert response.status_code == 404 + assert response.text == '{"detail":"No execution_id found"}' + +def test_get_run_resources_invalid_run_ids(mock_run_id_config_invalid,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock, \ + patch('superagi.controllers.api.agent.get_config', return_value="S3") as mock_get_config: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "v1/agent/resources/output", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_run_id_config_invalid + ) + assert response.status_code == 404 + assert response.text == '{"detail":"One or more run_ids not found"}' + +def test_resume_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "/v1/agent/1/resume", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_execution_state_change_input + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + + +def test_pause_agent_runs_agent_not_found(mock_execution_state_change_input,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "/v1/agent/1/pause", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_execution_state_change_input + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + +def test_create_run_agent_not_found(mock_agent_execution,mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session + mock_session = create_autospec(Session) + # # Configure session query methods to return None for agent + mock_session.query.return_value.filter.return_value.first.return_value = None + response = client.post( + "/v1/agent/1/run", + headers={"X-API-Key": mock_api_key_get}, # Provide the mock API key in headers + json=mock_agent_execution + ) + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' + +def test_create_run_project_not_matching_org(mock_agent_execution, mock_api_key_get): + with patch('superagi.helper.auth.get_organisation_from_api_key') as mock_get_user_org, \ + patch('superagi.helper.auth.validate_api_key') as mock_validate_api_key, \ + patch('superagi.helper.auth.db') as mock_auth_db, \ + patch('superagi.controllers.api.agent.db') as db_mock: + + # Mock the session and configure query methods to return agent and project + mock_session = create_autospec(Session) + mock_agent = Agent(id=1, project_id=1, agent_workflow_id=1) + mock_session.query.return_value.filter.return_value.first.return_value = mock_agent + mock_project = Project(id=1, organisation_id=2) # Different organisation ID + db_mock.Project.find_by_id.return_value = mock_project + db_mock.session.return_value.__enter__.return_value = mock_session + + response = client.post( + "/v1/agent/1/run", + headers={"X-API-Key": mock_api_key_get}, + json=mock_agent_execution + ) + + assert response.status_code == 404 + assert response.text == '{"detail":"Agent not found"}' diff --git a/tests/unit_tests/models/test_agent.py b/tests/unit_tests/models/test_agent.py index da7e6c4c1..18d614ce0 100644 --- a/tests/unit_tests/models/test_agent.py +++ b/tests/unit_tests/models/test_agent.py @@ -22,6 +22,27 @@ def test_get_agent_from_id(): # Assert that the returned agent object matches the mock agent assert agent == mock_agent + +def test_get_active_agent_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample agent ID + agent_id = 1 + + # Create a mock agent object to be returned by the session query + mock_agent = Agent(id=agent_id, name="Test Agent", project_id=1, description="Agent for testing",is_deleted=False) + + # Configure the session query to return the mock agent + session.query.return_value.filter.return_value.first.return_value = mock_agent + + # Call the method under test + agent = Agent.get_active_agent_by_id(session, agent_id) + + # Assert that the returned agent object matches the mock agent + assert agent == mock_agent + assert agent.is_deleted == False + def test_eval_tools_key(): key = "tools" value = "[1, 2, 3]" diff --git a/tests/unit_tests/models/test_agent_execution.py b/tests/unit_tests/models/test_agent_execution.py index a2c91e581..3ecbdc84f 100644 --- a/tests/unit_tests/models/test_agent_execution.py +++ b/tests/unit_tests/models/test_agent_execution.py @@ -9,8 +9,6 @@ from superagi.models.agent_execution import AgentExecution from superagi.models.workflows.agent_workflow_step import AgentWorkflowStep from superagi.models.workflows.iteration_workflow import IterationWorkflow - - def test_get_agent_execution_from_id(): # Create a mock session session = create_autospec(Session) @@ -89,3 +87,28 @@ def test_assign_next_step_id(mock_session, mocker): # Check that the attributes were updated assert mock_execution.current_agent_step_id == 2 assert mock_execution.iteration_workflow_step_id == 3 + +def test_get_execution_by_agent_id_and_status(): + session = create_autospec(Session) + + # Create a sample agent execution ID + agent_execution_id = 1 + + # Create a mock agent execution object to be returned by the session query + mock_agent_execution = AgentExecution(id=agent_execution_id, name="Test Execution", status="RUNNING") + + # Configure the session query to return the mock agent + session.query.return_value.filter.return_value.first.return_value = mock_agent_execution + + # Call the method under test + agent_execution = AgentExecution.get_execution_by_agent_id_and_status(session, agent_execution_id,"RUNNING") + + # Assert that the returned agent object matches the mock agent + assert agent_execution == mock_agent_execution + assert agent_execution.status == "RUNNING" + +@pytest.fixture +def mock_session(mocker): + return mocker.MagicMock() + + diff --git a/tests/unit_tests/models/test_agent_schedule.py b/tests/unit_tests/models/test_agent_schedule.py new file mode 100644 index 000000000..2f5f84355 --- /dev/null +++ b/tests/unit_tests/models/test_agent_schedule.py @@ -0,0 +1,23 @@ +from unittest.mock import create_autospec + +from sqlalchemy.orm import Session +from superagi.models.agent_schedule import AgentSchedule + +def test_find_by_agent_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample agent ID + agent_id = 1 + + # Create a mock agent schedule object to be returned by the session query + mock_agent_schedule = AgentSchedule(id=1,agent_id=agent_id, start_time="2023-08-10 12:17:00", recurrence_interval="2 Minutes", expiry_runs=2) + + # Configure the session query to return the mock agent + session.query.return_value.filter.return_value.first.return_value = mock_agent_schedule + + # Call the method under test + agent_schedule = AgentSchedule.find_by_agent_id(session, agent_id) + + # Assert that the returned agent object matches the mock agent + assert agent_schedule == mock_agent_schedule diff --git a/tests/unit_tests/models/test_api_key.py b/tests/unit_tests/models/test_api_key.py new file mode 100644 index 000000000..fe5752f9d --- /dev/null +++ b/tests/unit_tests/models/test_api_key.py @@ -0,0 +1,93 @@ +from unittest.mock import create_autospec + +from sqlalchemy.orm import Session +from superagi.models.api_key import ApiKey + +def test_get_by_org_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample organization ID + org_id = 1 + + # Create a mock ApiKey object to be returned by the session query + mock_api_keys = [ + ApiKey(id=1, org_id=org_id, key="key1", is_expired=False), + ApiKey(id=2, org_id=org_id, key="key2", is_expired=False), + ] + + # Configure the session query to return the mock api keys + session.query.return_value.filter.return_value.all.return_value = mock_api_keys + + # Call the method under test + api_keys = ApiKey.get_by_org_id(session, org_id) + + # Assert that the returned api keys match the mock api keys + assert api_keys == mock_api_keys + + +def test_get_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample api key ID + api_key_id = 1 + + # Create a mock ApiKey object to be returned by the session query + mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False) + + # Configure the session query to return the mock api key + session.query.return_value.filter.return_value.first.return_value = mock_api_key + + # Call the method under test + api_key = ApiKey.get_by_id(session, api_key_id) + + # Assert that the returned api key matches the mock api key + assert api_key == mock_api_key + +def test_delete_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample api key ID + api_key_id = 1 + + # Create a mock ApiKey object to be returned by the session query + mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False) + + # Configure the session query to return the mock api key + session.query.return_value.filter.return_value.first.return_value = mock_api_key + + # Call the method under test + ApiKey.delete_by_id(session, api_key_id) + + # Assert that the api key's is_expired attribute is set to True + assert mock_api_key.is_expired == True + + # Assert that the session.commit and session.flush methods were called + session.commit.assert_called_once() + session.flush.assert_called_once() + +def test_edit_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample api key ID and new name + api_key_id = 1 + new_name = "New Name" + + # Create a mock ApiKey object to be returned by the session query + mock_api_key = ApiKey(id=api_key_id, org_id=1, key="key1", is_expired=False) + + # Configure the session query to return the mock api key + session.query.return_value.filter.return_value.first.return_value = mock_api_key + + # Call the method under test + ApiKey.update_api_key(session, api_key_id, new_name) + + # Assert that the api key's name attribute is updated + assert mock_api_key.name == new_name + + # Assert that the session.commit and session.flush methods were called + session.commit.assert_called_once() + session.flush.assert_called_once() \ No newline at end of file diff --git a/tests/unit_tests/models/test_project.py b/tests/unit_tests/models/test_project.py new file mode 100644 index 000000000..9ac868217 --- /dev/null +++ b/tests/unit_tests/models/test_project.py @@ -0,0 +1,42 @@ +from unittest.mock import create_autospec + +from sqlalchemy.orm import Session +from superagi.models.project import Project + +def test_find_by_org_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample org ID + org_id = 123 + + # Create a mock project object to be returned by the session query + mock_project = Project(id=1, name="Test Project", organisation_id=org_id, description="Project for testing") + + # Configure the session query to return the mock project + session.query.return_value.filter.return_value.first.return_value = mock_project + + # Call the method under test + project = Project.find_by_org_id(session, org_id) + + # Assert that the returned project object matches the mock project + assert project == mock_project + +def test_find_by_id(): + # Create a mock session + session = create_autospec(Session) + + # Create a sample project ID + project_id = 123 + + # Create a mock project object to be returned by the session query + mock_project = Project(id=project_id, name="Test Project", organisation_id=1, description="Project for testing") + + # Configure the session query to return the mock project + session.query.return_value.filter.return_value.first.return_value = mock_project + + # Call the method under test + project = Project.find_by_id(session, project_id) + + # Assert that the returned project object matches the mock project + assert project == mock_project \ No newline at end of file diff --git a/tests/unit_tests/models/test_toolkit.py b/tests/unit_tests/models/test_toolkit.py index 82302d97e..339c970c9 100644 --- a/tests/unit_tests/models/test_toolkit.py +++ b/tests/unit_tests/models/test_toolkit.py @@ -1,10 +1,11 @@ -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch, call,create_autospec,Mock import pytest from superagi.models.organisation import Organisation from superagi.models.toolkit import Toolkit from superagi.models.tool import Tool +from sqlalchemy.orm import Session @pytest.fixture def mock_session(): @@ -243,3 +244,23 @@ def test_fetch_tool_ids_from_toolkit(mock_tool, mock_session): # Assert assert result == [mock_tool.id for _ in toolkit_ids] + +def test_get_tool_and_toolkit_arr_with_nonexistent_toolkit(): + # Create a mock session + session = create_autospec(Session) + + # Configure the session query to return None for toolkit + session.query.return_value.filter.return_value.first.return_value = None + + # Call the method under test with a non-existent toolkit + agent_config_tools_arr = [ + {"name": "NonExistentToolkit", "tools": ["Tool1", "Tool2"]}, + ] + + # Use a context manager to capture the raised exception and its message + with pytest.raises(Exception) as exc_info: + Toolkit.get_tool_and_toolkit_arr(session, agent_config_tools_arr) + + # Assert that the expected error message is contained within the raised exception message + expected_error_message = "One or more of the Tool(s)/Toolkit(s) does not exist." + assert expected_error_message in str(exc_info.value) From 0eb1ba0c0911f46c88ea37577c872165c9c21eeb Mon Sep 17 00:00:00 2001 From: Fluder-Paradyne <121793617+Fluder-Paradyne@users.noreply.github.com> Date: Thu, 17 Aug 2023 19:07:37 +0530 Subject: [PATCH 34/34] Update docker-compose-dev.yaml (#1073) --- docker-compose-dev.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml index d3caf6029..66a7086d1 100644 --- a/docker-compose-dev.yaml +++ b/docker-compose-dev.yaml @@ -28,10 +28,10 @@ services: NEXT_PUBLIC_API_BASE_URL: "/api" networks: - super_network - volumes: - - ./gui:/app - - /app/node_modules/ - - /app/.next/ +# volumes: +# - ./gui:/app +# - /app/node_modules/ +# - /app/.next/ super__redis: image: "redis/redis-stack-server:latest" networks: @@ -73,4 +73,4 @@ networks: driver: bridge volumes: superagi_postgres_data: - redis_data: \ No newline at end of file + redis_data: