diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py index 4f4da0843..5230a2776 100644 --- a/superagi/controllers/api/agent.py +++ b/superagi/controllers/api/agent.py @@ -57,7 +57,7 @@ def create_agent_with_config(agent_with_config: AgentConfigExtInput, api_key: str = Security(validate_api_key), organisation:Organisation = Depends(get_organisation_from_api_key)): project=Project.find_by_org_id(db.session, organisation.id) try: - tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,organisation.id,agent_with_config.tools) except Exception as e: raise HTTPException(status_code=404, detail=str(e)) @@ -177,7 +177,7 @@ def update_agent(agent_id: int, agent_with_config: AgentConfigUpdateExtInput,api raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot update") try: - tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,agent_with_config.tools) + tools_arr=Toolkit.get_tool_and_toolkit_arr(db.session,organisation.id,agent_with_config.tools) except Exception as e: raise HTTPException(status_code=404,detail=str(e)) diff --git a/superagi/models/toolkit.py b/superagi/models/toolkit.py index 5a9c0a0e9..9246111a1 100644 --- a/superagi/models/toolkit.py +++ b/superagi/models/toolkit.py @@ -140,12 +140,12 @@ def fetch_tool_ids_from_toolkit(cls, session, toolkit_ids): return agent_toolkit_tools @classmethod - def get_tool_and_toolkit_arr(cls, session, agent_config_tools_arr: list): + def get_tool_and_toolkit_arr(cls, session, organisation_id :int,agent_config_tools_arr: list): from superagi.models.tool import Tool toolkits_arr= set() tools_arr= set() for tool_obj in agent_config_tools_arr: - toolkit=session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip()).first() + toolkit=session.query(Toolkit).filter(Toolkit.name == tool_obj["name"].strip(), Toolkit.organisation_id == organisation_id).first() if toolkit is None: raise Exception("One or more of the Tool(s)/Toolkit(s) does not exist.") toolkits_arr.add(toolkit.id) diff --git a/tests/unit_tests/models/test_toolkit.py b/tests/unit_tests/models/test_toolkit.py index 339c970c9..297ef8716 100644 --- a/tests/unit_tests/models/test_toolkit.py +++ b/tests/unit_tests/models/test_toolkit.py @@ -259,7 +259,7 @@ def test_get_tool_and_toolkit_arr_with_nonexistent_toolkit(): # Use a context manager to capture the raised exception and its message with pytest.raises(Exception) as exc_info: - Toolkit.get_tool_and_toolkit_arr(session, agent_config_tools_arr) + Toolkit.get_tool_and_toolkit_arr(session,1, agent_config_tools_arr) # Assert that the expected error message is contained within the raised exception message expected_error_message = "One or more of the Tool(s)/Toolkit(s) does not exist."