diff --git a/gui/pages/Content/Agents/ActionConsole.js b/gui/pages/Content/Agents/ActionConsole.js index a1588a7e6..4dc621bcf 100644 --- a/gui/pages/Content/Agents/ActionConsole.js +++ b/gui/pages/Content/Agents/ActionConsole.js @@ -1,83 +1,116 @@ -import React from 'react'; +import React, { useState, useEffect } from 'react'; import styles from './Agents.module.css'; import Image from "next/image"; +import { updatePermissions } from '@/pages/api/DashboardService'; -export default function ActionConsole() { - const actions = [ - { - title: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", - type: "text", - timeStamp: "2min ago" - }, - { - title: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam.", - type: "boolean", - timeStamp: "2min ago" - }, - { - title: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", - type: "text", - timeStamp: "2min ago" - }, - { - title: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam.", - type: "boolean", - timeStamp: "2min ago" - }, - { - title: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", - type: "text", - timeStamp: "2min ago" - }, - { - title: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam.", - type: "boolean", - timeStamp: "2min ago" - } - ] +export default function ActionConsole({ actions }) { + const [hiddenActions, setHiddenActions] = useState([]); + const [reasons, setReasons] = useState(actions.map(() => '')); + const [localActions, setLocalActions] = useState(actions); + const [denied, setDenied] = useState([]); + const [localActionIds, setLocalActionIds] = useState([]); - return (<> -
- {actions.map((action, index) => (
- {action.type === "notification" &&
-
{action.title}
-
} - {action.type === "boolean" &&
-
{action.title}
-
- -
-
- -
-
} - {action.type === "text" &&
-
{action.title}
-
-
-
- -
-
- -
-
-
} -
-
- schedule-icon -
-
{action.timeStamp}
-
-
))} -
- ) + useEffect(() => { + const updatedActions = actions.filter( + (action) => !localActionIds.includes(action.id) + ); + + if (updatedActions.length > 0) { + setLocalActions( + localActions.map((localAction) => + updatedActions.find(({ id }) => id === localAction.id) || localAction + ) + ); + + const updatedDenied = updatedActions.map(() => false); + const updatedReasons = updatedActions.map(() => ''); + + setDenied((prev) => prev.map((value, index) => updatedDenied[index] || value)); + setReasons((prev) => prev.map((value, index) => updatedReasons[index] || value)); + + setLocalActionIds([...localActionIds, ...updatedActions.map(({ id }) => id)]); + } + }, [actions]); + + const handleDeny = index => { + const newDeniedState = [...denied]; + newDeniedState[index] = !newDeniedState[index]; + setDenied(newDeniedState); + }; + + const formatDate = (dateString) => { + const now = new Date(); + const date = new Date(dateString); + const seconds = Math.floor((now - date) / 1000); + const minutes = Math.floor(seconds / 60); + const hours = Math.floor(minutes / 60); + const days = Math.floor(hours / 24); + const weeks = Math.floor(days / 7); + const months = Math.floor(days / 30); + const years = Math.floor(days / 365); + + if (years > 0) return `${years} yr${years === 1 ? '' : 's'}`; + if (months > 0) return `${months} mon${months === 1 ? '' : 's'}`; + if (weeks > 0) return `${weeks} wk${weeks === 1 ? '' : 's'}`; + if (days > 0) return `${days} day${days === 1 ? '' : 's'}`; + if (hours > 0) return `${hours} hr${hours === 1 ? '' : 's'}`; + if (minutes > 0) return `${minutes} min${minutes === 1 ? '' : 's'}`; + + return `${seconds} sec${seconds === 1 ? '' : 's'}`; + }; + + const handleSelection = (index, status, permissionId) => { + setHiddenActions([...hiddenActions, index]); + + const data = { + status: status, + user_feedback: reasons[index], + }; + + updatePermissions(permissionId, data).then((response) => { + console.log("voila") + }); + }; + + return ( + <> + {actions.some(action => action.status === "PENDING") ? (
+ {actions.map((action, index) => action.status === "PENDING" && !hiddenActions.includes(index) && ( +
+
+
Tool {action.tool_name} is seeking for Permissions
+ {denied[index] && ( +
+
Provide Feedback (Optional)
+ {const newReasons = [...reasons];newReasons[index] = e.target.value;setReasons(newReasons);}} placeholder="Enter your input here" className="input_medium" /> +
+ )} + {denied[index] ? ( +
+ + +
+ ) : ( +
+ + +
+ )} +
+
+
+ schedule-icon +
+
{formatDate(action.created_at)}
+
+
+ ))} +
): + ( +
+ no permissions + No Actions to Display! +
)} + + ); } \ No newline at end of file diff --git a/gui/pages/Content/Agents/ActivityFeed.js b/gui/pages/Content/Agents/ActivityFeed.js index 318a85b2f..f0ce5f5b0 100644 --- a/gui/pages/Content/Agents/ActivityFeed.js +++ b/gui/pages/Content/Agents/ActivityFeed.js @@ -5,7 +5,7 @@ import Image from "next/image"; import {formatTime} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; -export default function ActivityFeed({selectedRunId, selectedView}) { +export default function ActivityFeed({selectedRunId, selectedView, setFetchedData }) { const [loadingText, setLoadingText] = useState("Thinking"); const [feeds, setFeeds] = useState([]); const feedContainerRef = useRef(null); @@ -65,6 +65,7 @@ export default function ActivityFeed({selectedRunId, selectedView}) { const data = response.data; setFeeds(data.feeds); setRunStatus(data.status); + setFetchedData(data.permissions); }) .catch((error) => { console.error('Error fetching execution feeds:', error); diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 671e63bfc..dca97e233 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -61,7 +61,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen const databaseRef = useRef(null); const [databaseDropdown, setDatabaseDropdown] = useState(false); - const permissions = ["God Mode"] + const permissions = ["God Mode","RESTRICTED (Will ask for permission before using any tool)"] const [permission, setPermission] = useState(permissions[0]); const permissionRef = useRef(null); const [permissionDropdown, setPermissionDropdown] = useState(false); @@ -339,6 +339,12 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen setCreateClickable(false); + // if permission has word restricted change the permission to + let permission_type = permission; + if (permission.includes("RESTRICTED")) { + permission_type = "RESTRICTED"; + } + const agentData = { "name": agentName, "project_id": selectedProjectId, @@ -352,7 +358,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen "iteration_interval": stepTime, "model": model, "max_iterations": maxIterations, - "permission_type": permission, + "permission_type": permission_type, "LTM_DB": longTermMemory ? database : null, "memory_window": rollingWindow }; diff --git a/gui/pages/Content/Agents/AgentWorkspace.js b/gui/pages/Content/Agents/AgentWorkspace.js index d54b9bb66..cdecc6226 100644 --- a/gui/pages/Content/Agents/AgentWorkspace.js +++ b/gui/pages/Content/Agents/AgentWorkspace.js @@ -24,6 +24,7 @@ export default function AgentWorkspace({agentId, selectedView}) { const [agentDetails, setAgentDetails] = useState(null) const [agentExecutions, setAgentExecutions] = useState(null) const [dropdown, setDropdown] = useState(false); + const [fetchedData, setFetchedData] = useState(null); const [instructions, setInstructions] = useState(['']); const addInstruction = () => { @@ -216,18 +217,24 @@ export default function AgentWorkspace({agentId, selectedView}) {
- {leftPanel === 'activity_feed' &&
} + {leftPanel === 'activity_feed' &&
+ +
} {leftPanel === 'agent_type' &&
}
- {/*
*/} - {/* */} - {/*
*/} +
+ +
{/*
*/} {/*
- {rightPanel === 'action_console' && agentDetails && agentDetails?.permission_type !== 'God Mode' &&
} + {rightPanel === 'action_console' && agentDetails && agentDetails?.permission_type !== 'God Mode' && ( +
+ +
+ )} {rightPanel === 'details' &&
} {rightPanel === 'resource_manager' &&
}
diff --git a/gui/pages/Content/Agents/Agents.module.css b/gui/pages/Content/Agents/Agents.module.css index cdc311956..d0ef499b9 100644 --- a/gui/pages/Content/Agents/Agents.module.css +++ b/gui/pages/Content/Agents/Agents.module.css @@ -344,4 +344,14 @@ font-size: 13px; font-weight: 500; font-family: 'Source Code Pro'; -} \ No newline at end of file +} + +.text_12_n +{ + font-style: normal; + font-weight: 400; + font-size: 12px; + line-height: 14px; + color: #FFFFFF; + margin-top: 2px; +} diff --git a/gui/pages/_app.css b/gui/pages/_app.css index a6e5db5b1..06d049f47 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -487,7 +487,10 @@ p { border: 1px solid rgba(255, 255, 255, 0.14); text-align: center; padding: 5px 15px; - display: -webkit-box; + display: -webkit-flex; + flex-direction: row; + align-items: center; + gap: 6px; -webkit-box-orient: vertical; -webkit-line-clamp: 1; overflow: hidden; diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index 502c65aca..8761d35a4 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -102,4 +102,8 @@ export const fetchAgentTemplateConfigLocal = (templateId) => { export const installAgentTemplate = (templateId) => { return api.post(`agent_templates/download?agent_template_id=${templateId}`); +} + +export const updatePermissions = (permissionId, data) => { + return api.put(`/agentexecutionpermissions/update/status/${permissionId}`, data) } \ No newline at end of file diff --git a/gui/public/images/check.svg b/gui/public/images/check.svg new file mode 100644 index 000000000..27c0772fa --- /dev/null +++ b/gui/public/images/check.svg @@ -0,0 +1,3 @@ + + + diff --git a/gui/public/images/no_permissions.svg b/gui/public/images/no_permissions.svg new file mode 100644 index 000000000..425b25211 --- /dev/null +++ b/gui/public/images/no_permissions.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/gui/public/images/undo.svg b/gui/public/images/undo.svg new file mode 100644 index 000000000..98b226fa4 --- /dev/null +++ b/gui/public/images/undo.svg @@ -0,0 +1,3 @@ + + + diff --git a/main.py b/main.py index 19ed102a1..502c8b107 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,7 @@ from superagi.controllers.agent_config import router as agent_config_router from superagi.controllers.agent_execution import router as agent_execution_router from superagi.controllers.agent_execution_feed import router as agent_execution_feed_router +from superagi.controllers.agent_execution_permission import router as agent_execution_permission_router from superagi.controllers.budget import router as budget_router from superagi.controllers.organisation import router as organisation_router from superagi.controllers.project import router as project_router @@ -83,8 +84,9 @@ app.include_router(agent_config_router, prefix="/agentconfigs") app.include_router(agent_execution_router, prefix="/agentexecutions") app.include_router(agent_execution_feed_router, prefix="/agentexecutionfeeds") +app.include_router(agent_execution_permission_router, prefix="/agentexecutionpermissions") app.include_router(resources_router, prefix="/resources") -app.include_router(config_router,prefix="/configs") +app.include_router(config_router, prefix="/configs") app.include_router(agent_template_router,prefix="/agent_templates") app.include_router(agent_workflow_router,prefix="/agent_workflows") diff --git a/migrations/versions/1d54db311055_add_permissions.py b/migrations/versions/1d54db311055_add_permissions.py new file mode 100644 index 000000000..9d63b2501 --- /dev/null +++ b/migrations/versions/1d54db311055_add_permissions.py @@ -0,0 +1,45 @@ +"""add permissions + +Revision ID: 1d54db311055 +Revises: 3356a2f89a33 +Create Date: 2023-06-14 11:05:59.678961 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1d54db311055' +down_revision = '516ecc1c723d' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('agent_execution_permissions', + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('agent_execution_id', sa.Integer(), nullable=True), + sa.Column('agent_id', sa.Integer(), nullable=True), + sa.Column('status', sa.String(), nullable=True), + sa.Column('tool_name', sa.String(), nullable=True), + sa.Column('user_feedback', sa.Text(), nullable=True), + sa.Column('assistant_reply', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.add_column('agent_executions', sa.Column('permission_id', sa.Integer(), nullable=True)) + # index on agent_execution_id + op.create_index(op.f('ix_agent_execution_permissions_agent_execution_id') + , 'agent_execution_permissions', ['agent_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('agent_executions', 'permission_id') + op.drop_table('agent_execution_permissions') + # ### end Alembic commands ### diff --git a/superagi/agent/super_agi.py b/superagi/agent/super_agi.py index d9fe6817d..4be3271c6 100644 --- a/superagi/agent/super_agi.py +++ b/superagi/agent/super_agi.py @@ -23,6 +23,7 @@ from superagi.models.agent_execution import AgentExecution # from superagi.models.types.agent_with_config import AgentWithConfig from superagi.models.agent_execution_feed import AgentExecutionFeed +from superagi.models.agent_execution_permission import AgentExecutionPermission from superagi.models.agent_workflow_step import AgentWorkflowStep from superagi.models.db import connect_db from superagi.tools.base_tool import BaseTool @@ -112,7 +113,6 @@ def split_history(self, history: List, pending_token_limit: int) -> Tuple[List[B i -= 1 return [], history - def execute(self, workflow_step: AgentWorkflowStep): print(self.tools) @@ -121,7 +121,8 @@ def execute(self, workflow_step: AgentWorkflowStep): task_queue = TaskQueue(str(agent_execution_id)) token_limit = TokenCounter.token_limit() - agent_feeds = self.fetch_agent_feeds(session, self.agent_config["agent_execution_id"], self.agent_config["agent_id"]) + agent_feeds = self.fetch_agent_feeds(session, self.agent_config["agent_execution_id"], + self.agent_config["agent_id"]) current_calls = 0 if len(agent_feeds) <= 0: task_queue.clear_tasks() @@ -129,7 +130,8 @@ def execute(self, workflow_step: AgentWorkflowStep): max_token_limit = 600 # adding history to the messages if workflow_step.history_enabled: - prompt = self.build_agent_prompt(workflow_step.prompt, task_queue=task_queue, max_token_limit=max_token_limit) + prompt = self.build_agent_prompt(workflow_step.prompt, task_queue=task_queue, + max_token_limit=max_token_limit) messages.append({"role": "system", "content": prompt}) messages.append({"role": "system", "content": f"The current time and date is {time.strftime('%c')}"}) base_token_limit = TokenCounter.count_message_tokens(messages, self.llm.get_model()) @@ -140,7 +142,8 @@ def execute(self, workflow_step: AgentWorkflowStep): messages.append({"role": history["role"], "content": history["content"]}) messages.append({"role": "user", "content": workflow_step.completion_prompt}) else: - prompt = self.build_agent_prompt(workflow_step.prompt, task_queue=task_queue, max_token_limit=max_token_limit) + prompt = self.build_agent_prompt(workflow_step.prompt, task_queue=task_queue, + max_token_limit=max_token_limit) messages.append({"role": "system", "content": prompt}) # agent_execution_feed = AgentExecutionFeed(agent_execution_id=self.agent_config["agent_execution_id"], # agent_id=self.agent_config["agent_id"], feed=template_step.prompt, @@ -170,14 +173,22 @@ def execute(self, workflow_step: AgentWorkflowStep): if workflow_step.output_type == "tools": agent_execution_feed = AgentExecutionFeed(agent_execution_id=self.agent_config["agent_execution_id"], - agent_id=self.agent_config["agent_id"], feed=assistant_reply, - role="assistant") + agent_id=self.agent_config["agent_id"], feed=assistant_reply, + role="assistant") session.add(agent_execution_feed) session.commit() + + # check if permission is required for the tool in restricted mode + is_permission_required, response = self.check_permission_in_restricted_mode(assistant_reply) + if is_permission_required: + return response + tool_response = self.handle_tool_response(assistant_reply) agent_execution_feed = AgentExecutionFeed(agent_execution_id=self.agent_config["agent_execution_id"], - agent_id=self.agent_config["agent_id"], feed=tool_response["result"], - role="system") + agent_id=self.agent_config["agent_id"], + feed=tool_response["result"], + role="system" + ) session.add(agent_execution_feed) final_response = tool_response final_response["pending_task_count"] = len(task_queue.get_tasks()) @@ -274,7 +285,7 @@ def build_agent_prompt(self, prompt: str, task_queue: TaskQueue, max_token_limit add_finish_tool = True if len(pending_tasks) > 0 or len(completed_tasks) > 0: add_finish_tool = False - print(self.tools) + prompt = AgentPromptBuilder.replace_main_variables(prompt, self.agent_config["goal"], self.agent_config["instruction"], self.agent_config["constraints"], self.tools, add_finish_tool) @@ -291,4 +302,24 @@ def build_agent_prompt(self, prompt: str, task_queue: TaskQueue, max_token_limit token_limit = TokenCounter.token_limit() - max_token_limit prompt = AgentPromptBuilder.replace_task_based_variables(prompt, current_task, last_task, last_task_result, pending_tasks, completed_tasks, token_limit) - return prompt \ No newline at end of file + return prompt + + def check_permission_in_restricted_mode(self, assistant_reply: str): + action = self.output_parser.parse(assistant_reply) + tools = {t.name: t for t in self.tools} + + excluded_tools = [FINISH, '', None] + + if self.agent_config["permission_type"].upper() == "RESTRICTED" and action.name not in excluded_tools and \ + tools.get(action.name) and tools[action.name].permission_required: + new_agent_execution_permission = AgentExecutionPermission( + agent_execution_id=self.agent_config["agent_execution_id"], + status="PENDING", + agent_id=self.agent_config["agent_id"], + tool_name=action.name, + assistant_reply=assistant_reply) + + session.add(new_agent_execution_permission) + session.commit() + return True, {"result": "WAITING_FOR_PERMISSION", "permission_id": new_agent_execution_permission.id} + return False, None diff --git a/superagi/controllers/agent_execution_feed.py b/superagi/controllers/agent_execution_feed.py index 1d95b8a0f..758d0352d 100644 --- a/superagi/controllers/agent_execution_feed.py +++ b/superagi/controllers/agent_execution_feed.py @@ -7,6 +7,7 @@ from superagi.agent.task_queue import TaskQueue from superagi.helper.auth import check_auth +from superagi.models.agent_execution_permission import AgentExecutionPermission from superagi.helper.feed_parser import parse_feed from superagi.models.agent_execution import AgentExecution from superagi.models.agent_execution_feed import AgentExecutionFeed @@ -131,10 +132,27 @@ def get_agent_execution_feed(agent_execution_id: int, # # parse json final_feeds = [] for feed in feeds: - final_feeds.append(parse_feed(feed)) + if feed.feed != "": + final_feeds.append(parse_feed(feed)) + + # get all permissions + execution_permissions = db.session.query(AgentExecutionPermission).\ + filter_by(agent_execution_id=agent_execution_id, status="PENDING"). \ + order_by(asc(AgentExecutionPermission.created_at)).all() + + permissions = [ + { + "id": permission.id, + "created_at": permission.created_at, + "response": permission.user_feedback, + "status": permission.status, + "tool_name": permission.tool_name + } for permission in execution_permissions + ] return { "status": agent_execution.status, - "feeds": final_feeds + "feeds": final_feeds, + "permissions": permissions } @@ -150,7 +168,6 @@ def get_execution_tasks(agent_execution_id: int, Returns: dict: The tasks and completed tasks for the agent execution. """ - task_queue = TaskQueue(str(agent_execution_id)) tasks = [] for task in task_queue.get_tasks(): diff --git a/superagi/controllers/agent_execution_permission.py b/superagi/controllers/agent_execution_permission.py new file mode 100644 index 000000000..e7c4f9e3b --- /dev/null +++ b/superagi/controllers/agent_execution_permission.py @@ -0,0 +1,129 @@ +from datetime import datetime +from typing import Annotated + +from fastapi_sqlalchemy import db +from fastapi import HTTPException, Depends, Body +from fastapi_jwt_auth import AuthJWT + +from superagi.models.agent_execution_permission import AgentExecutionPermission +from superagi.worker import execute_agent +from fastapi import APIRouter +from pydantic_sqlalchemy import sqlalchemy_to_pydantic +from superagi.helper.auth import check_auth + +router = APIRouter() + + +@router.get("/get/{agent_execution_permission_id}") +def get_agent_execution_permission(agent_execution_permission_id: int, + Authorize: AuthJWT = Depends(check_auth)): + """ + Get an agent execution permission by its ID. + + Args: + agent_execution_permission_id (int): The ID of the agent execution permission. + Authorize (AuthJWT, optional): Authentication object. Defaults to Depends(check_auth). + + Raises: + HTTPException: If the agent execution permission is not found. + + Returns: + AgentExecutionPermission: The requested agent execution permission. + """ + + db_agent_execution_permission = db.session.query(AgentExecutionPermission).get(agent_execution_permission_id) + if not db_agent_execution_permission: + raise HTTPException(status_code=404, detail="Agent execution permission not found") + return db_agent_execution_permission + + +@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecutionPermission)) +def create_agent_execution_permission( + agent_execution_permission: sqlalchemy_to_pydantic(AgentExecutionPermission, exclude=["id"]) + , Authorize: AuthJWT = Depends(check_auth)): + """ + Create a new agent execution permission. + + Args: + agent_execution_permission : An instance of AgentExecutionPermission model as json. + Authorize (AuthJWT, optional): Authorization token, by default depends on the check_auth function. + + Returns: + new_agent_execution_permission: A newly created agent execution permission instance. + """ + new_agent_execution_permission = AgentExecutionPermission(**agent_execution_permission.dict()) + db.session.add(new_agent_execution_permission) + db.session.commit() + return new_agent_execution_permission + + +@router.patch("/update/{agent_execution_permission_id}", + response_model=sqlalchemy_to_pydantic(AgentExecutionPermission, exclude=["id"])) +def update_agent_execution_permission(agent_execution_permission_id: int, + agent_execution_permission: sqlalchemy_to_pydantic(AgentExecutionPermission, + exclude=["id"]), + Authorize: AuthJWT = Depends(check_auth)): + """ + Update an AgentExecutionPermission in the database. + + Given an agent_execution_permission_id and the updated agent_execution_permission, this function updates the + corresponding AgentExecutionPermission in the database. If the AgentExecutionPermission is not found, an HTTPException + is raised. + + Args: + agent_execution_permission_id (int): The ID of the AgentExecutionPermission to update. + agent_execution_permission : The updated AgentExecutionPermission object as json. + Authorize (AuthJWT, optional): Dependency to authenticate the user. + + Returns: + db_agent_execution_permission (AgentExecutionPermission): The updated AgentExecutionPermission in the database. + + Raises: + HTTPException: If the AgentExecutionPermission is not found in the database. + """ + db_agent_execution_permission = db.session.query(AgentExecutionPermission).get(agent_execution_permission_id) + if not db_agent_execution_permission: + raise HTTPException(status_code=404, detail="Agent execution permission not found") + + for key, value in agent_execution_permission.dict().items(): + setattr(db_agent_execution_permission, key, value) + + db.session.commit() + return db_agent_execution_permission + + +@router.put("/update/status/{agent_execution_permission_id}") +def update_agent_execution_permission_status(agent_execution_permission_id: int, + status: Annotated[bool, Body(embed=True)], + user_feedback: Annotated[str, Body(embed=True)] = "", + Authorize: AuthJWT = Depends(check_auth)): + """ + Update the execution permission status of an agent in the database. + + This function updates the execution permission status of an agent in the database. The status can be + either "APPROVED" or "REJECTED". The function also updates the user feedback if provided, + commits the changes to the database, and enqueues the agent for execution. + + :params: + - agent_execution_permission_id (int): The ID of the agent execution permission + - status (bool): The status of the agent execution permission, True for "APPROVED", False for "REJECTED" + - user_feedback (str): Optional user feedback on the status update + - Authorize (AuthJWT): Dependency function to check user authorization + + :return: + - A dictionary containing a "success" key with the value True to indicate a successful update. + """ + + agent_execution_permission = db.session.query(AgentExecutionPermission).get(agent_execution_permission_id) + print(agent_execution_permission) + if agent_execution_permission is None: + raise HTTPException(status_code=400, detail="Invalid Request") + if status is None: + raise HTTPException(status_code=400, detail="Invalid Request status is required") + agent_execution_permission.status = "APPROVED" if status else "REJECTED" + agent_execution_permission.user_feedback = user_feedback.strip() if len(user_feedback.strip()) > 0 else None + db.session.commit() + + execute_agent.delay(agent_execution_permission.agent_execution_id, datetime.now()) + + return {"success": True} diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index bcffaea92..468a14817 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -9,6 +9,8 @@ from superagi.llms.openai import OpenAi from superagi.models.agent import Agent from superagi.models.agent_execution import AgentExecution +from superagi.models.agent_execution_feed import AgentExecutionFeed +from superagi.models.agent_execution_permission import AgentExecutionPermission from superagi.models.agent_workflow_step import AgentWorkflowStep from superagi.models.configuration import Configuration from superagi.models.db import connect_db @@ -118,7 +120,7 @@ def execute_next_action(self, agent_execution_id): agent = session.query(Agent).filter(Agent.id == agent_execution.agent_id).first() # if agent_execution.status == "PAUSED" or agent_execution.status == "TERMINATED" or agent_execution == "COMPLETED": # return - if agent_execution.status != "RUNNING": + if agent_execution.status != "RUNNING" and agent_execution.status != "WAITING_FOR_PERMISSION": return if not agent: @@ -171,6 +173,11 @@ def execute_next_action(self, agent_execution_id): memory=memory, agent_config=parsed_config) + try: + self.handle_wait_for_permission(agent_execution, spawned_agent, session) + except ValueError: + return + agent_workflow_step = session.query(AgentWorkflowStep).filter( AgentWorkflowStep.id == agent_execution.current_step_id).first() response = spawned_agent.execute(agent_workflow_step) @@ -181,7 +188,11 @@ def execute_next_action(self, agent_execution_id): if response["result"] == "COMPLETE": db_agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first() db_agent_execution.status = "COMPLETED" - + session.commit() + elif response["result"] == "WAITING_FOR_PERMISSION": + db_agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first() + db_agent_execution.status = "WAITING_FOR_PERMISSION" + db_agent_execution.permission_id = response.get("permission_id", None) session.commit() else: logger.info("Starting next job for agent execution id: ", agent_execution_id) @@ -223,3 +234,36 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key tool.agent_id = agent_id new_tools.append(tool) return tools + + def handle_wait_for_permission(self, agent_execution, spawned_agent, session): + """ + Handles the wait for permission when the agent execution is waiting for permission. + + Args: + agent_execution (AgentExecution): The agent execution. + spawned_agent (SuperAgi): The spawned agent. + session (Session): The database session object. + + Raises: + ValueError: If the permission is still pending. + """ + if agent_execution.status != "WAITING_FOR_PERMISSION": + return + agent_execution_permission = session.query(AgentExecutionPermission).filter( + AgentExecutionPermission.id == agent_execution.permission_id).first() + if agent_execution_permission.status == "PENDING": + raise ValueError("Permission is still pending") + if agent_execution_permission.status == "APPROVED": + result = spawned_agent.handle_tool_response(agent_execution_permission.assistant_reply).get("result") + else: + result = f"User denied the permission to run the tool {agent_execution_permission.tool_name}" \ + f"{' and has given the following feedback : ' + agent_execution_permission.user_feedback if agent_execution_permission.user_feedback else ''}" + + agent_execution_feed = AgentExecutionFeed(agent_execution_id=agent_execution_permission.agent_execution_id, + agent_id=agent_execution_permission.agent_id, + feed=result, + role="system" + ) + session.add(agent_execution_feed) + agent_execution.status = "RUNNING" + session.commit() diff --git a/superagi/models/agent_execution.py b/superagi/models/agent_execution.py index 28cec530d..0419c322e 100644 --- a/superagi/models/agent_execution.py +++ b/superagi/models/agent_execution.py @@ -32,6 +32,7 @@ class AgentExecution(DBBaseModel): num_of_calls = Column(Integer, default=0) num_of_tokens = Column(Integer, default=0) current_step_id = Column(Integer) + permission_id = Column(Integer) def __repr__(self): """ diff --git a/superagi/models/agent_execution_feed.py b/superagi/models/agent_execution_feed.py index 07a548cc5..9a9cb94f1 100644 --- a/superagi/models/agent_execution_feed.py +++ b/superagi/models/agent_execution_feed.py @@ -35,4 +35,4 @@ def __repr__(self): return f"AgentExecutionFeed(id={self.id}, " \ f"agent_execution_id={self.agent_execution_id}, " \ - f"feed='{self.feed}', type='{self.type}', extra_info={self.extra_info})" + f"feed='{self.feed}', role='{self.role}', extra_info={self.extra_info})" diff --git a/superagi/models/agent_execution_permission.py b/superagi/models/agent_execution_permission.py new file mode 100644 index 000000000..32719a49c --- /dev/null +++ b/superagi/models/agent_execution_permission.py @@ -0,0 +1,41 @@ +from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey +from sqlalchemy.orm import relationship +from superagi.models.base_model import DBBaseModel +from superagi.models.agent_execution import AgentExecution + + +class AgentExecutionPermission(DBBaseModel): + """ + Represents an Agent Execution Permission record in the database. + + Attributes: + id (Integer): The primary key of the agent execution permission record. + agent_execution_id (Integer): The ID of the agent execution this permission record is associated with. + agent_id (Integer): The ID of the agent this permission record is associated with. + status (String): The status of the agent execution permission, APPROVED, REJECTED, or PENDING. + tool_name (String): The name of the tool or service that requires the permission. + user_feedback (Text): Any feedback provided by the user regarding the agent execution permission. + assistant_reply (Text): The reply or message sent back to the user by the assistant. + + Methods: + __repr__: Returns a string representation of the AgentExecutionPermission instance. + """ + __tablename__ = 'agent_execution_permissions' + + id = Column(Integer, primary_key=True) + agent_execution_id = Column(Integer) + agent_id = Column(Integer) + status = Column(String) + tool_name = Column(String) + user_feedback = Column(Text) + assistant_reply = Column(Text) + + def __repr__(self): + """ + Returns a string representation of the AgentExecutionPermission instance. + """ + return f"AgentExecutionPermission(id={self.id}, " \ + f"agent_execution_id={self.agent_execution_id}, " \ + f"agent_id={self.agent_id}, " \ + f"status={self.status}, " \ + f"response={self.user_feedback})" diff --git a/superagi/models/base_model.py b/superagi/models/base_model.py index 9ca0e3d2a..8b872b11b 100644 --- a/superagi/models/base_model.py +++ b/superagi/models/base_model.py @@ -1,3 +1,5 @@ +import json + from sqlalchemy import Column, DateTime, INTEGER from sqlalchemy.ext.declarative import declarative_base from datetime import datetime @@ -7,14 +9,60 @@ class DBBaseModel(Base): """ - Base model for SQLAlchemy models. + DBBaseModel is an abstract base class for all SQLAlchemy ORM models , + providing common columns and functionality. Attributes: - created_at (DateTime): The timestamp indicating the creation time of the record. - updated_at (DateTime): The timestamp indicating the last update time of the record. - """ + created_at: Datetime column to store the timestamp about when a row is created. + updated_at: Datetime column to store the timestamp about when a row is updated. + Methods: + to_dict: Converts the current object to a dictionary. + to_json: Converts the current object to a JSON string. + from_json: Creates a new object of the class using the provided JSON data. + __repr__: Returns a string representation of the current object. + """ __abstract__ = True # id = Column(INTEGER,primary_key=True,autoincrement=True) created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + def to_dict(self): + """ + Converts the current SQLAlchemy ORM object to a dictionary representation. + + Returns: + A dictionary mapping column names to their corresponding values. + """ + return {column.name: getattr(self, column.name) for column in self.__table__.columns} + + def to_json(self): + """ + Converts the current SQLAlchemy ORM object to a JSON string representation. + + Returns: + A JSON string representing the object with column names as keys and their corresponding values. + """ + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_data): + """ + Creates a new SQLAlchemy ORM object of the class using the provided JSON data. + + Args: json_data (str): A JSON string representing the object with column names as keys and their + corresponding values. + + Returns: + A new SQLAlchemy ORM object of the class. + """ + return cls(**json.loads(json_data)) + + def __repr__(self): + """ + Returns a string representation of the current SQLAlchemy ORM object. + + Returns: + A string with the format " ()". + """ + return f"{self.__class__.__name__} ({self.to_dict()})" diff --git a/superagi/models/db.py b/superagi/models/db.py index 2fadc0e9e..d49b5793e 100644 --- a/superagi/models/db.py +++ b/superagi/models/db.py @@ -37,5 +37,5 @@ def connect_db(): logger.info("Connected to the database! @ " + db_url) connection.close() except Exception as e: - logger.error("Unable to connect to the database:", e) + logger.error(f"Unable to connect to the database:{e}") return engine diff --git a/superagi/tools/base_tool.py b/superagi/tools/base_tool.py index 5554a0678..6e8c7677b 100644 --- a/superagi/tools/base_tool.py +++ b/superagi/tools/base_tool.py @@ -58,6 +58,7 @@ class BaseTool(BaseModel): name: str = None description: str args_schema: Type[BaseModel] = None + permission_required: bool = True @property def args(self): @@ -78,6 +79,7 @@ def _execute(self, *args: Any, **kwargs: Any): def max_token_limit(self): return get_config("MAX_TOOL_TOKEN_LIMIT", 600) + def _parse_input( self, tool_input: Union[str, Dict], diff --git a/superagi/tools/thinking/tools.py b/superagi/tools/thinking/tools.py index 58e123521..50c4c699a 100644 --- a/superagi/tools/thinking/tools.py +++ b/superagi/tools/thinking/tools.py @@ -35,6 +35,7 @@ class ThinkingTool(BaseTool): ) args_schema: Type[ThinkingSchema] = ThinkingSchema goals: List[str] = [] + permission_required: bool = False class Config: arbitrary_types_allowed = True diff --git a/superagi/worker.py b/superagi/worker.py index f1c31da23..c107da67e 100644 --- a/superagi/worker.py +++ b/superagi/worker.py @@ -4,7 +4,6 @@ from celery import Celery from superagi.config.config import get_config -from superagi.jobs.agent_executor import AgentExecutor redis_url = get_config('REDIS_URL') app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"]) @@ -15,5 +14,6 @@ @app.task(name="execute_agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5) def execute_agent(agent_execution_id: int, time): """Execute an agent step in background.""" + from superagi.jobs.agent_executor import AgentExecutor logger.info("Execute agent:" + str(time) + "," + str(agent_execution_id)) AgentExecutor().execute_next_action(agent_execution_id=agent_execution_id) diff --git a/tests/agent_permissions/__init__.py b/tests/agent_permissions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/agent_permissions/test_check_permission_in_restricted_mode.py b/tests/agent_permissions/test_check_permission_in_restricted_mode.py new file mode 100644 index 000000000..220ac868c --- /dev/null +++ b/tests/agent_permissions/test_check_permission_in_restricted_mode.py @@ -0,0 +1,61 @@ +import pytest +from unittest.mock import MagicMock, Mock +from superagi.agent.output_parser import AgentOutputParser +from superagi.agent.super_agi import SuperAgi +from superagi.llms.base_llm import BaseLlm +from superagi.tools.base_tool import BaseTool +from superagi.vector_store.base import VectorStore + + +class MockTool(BaseTool): + def __init__(self, name, permission_required=False): + super().__init__(name=name, permission_required=permission_required, description="Mock tool") + + def _execute(self, *args, **kwargs): + pass + +@pytest.fixture +def super_agi(): + ai_name = "test_ai" + ai_role = "test_role" + llm = Mock(spec=BaseLlm) + memory = Mock(spec=VectorStore) + tools = [MockTool(name="NotRestrictedTool", permission_required=False), + MockTool(name="RestrictedTool", permission_required=True)] + agent_config = {"permission_type": "RESTRICTED", "agent_execution_id": 1, "agent_id": 2} + output_parser = AgentOutputParser() + + super_agi = SuperAgi(ai_name, ai_role, llm, memory, tools, agent_config, output_parser) + return super_agi + + +def test_check_permission_in_restricted_mode_not_required(super_agi): + assistant_reply = "Test reply" + + super_agi.output_parser.parse = MagicMock( + return_value=MockTool(name="NotRestrictedTool", permission_required=False)) + result, output = super_agi.check_permission_in_restricted_mode(assistant_reply) + assert not result + assert output is None + + +def test_check_permission_in_restricted_mode_permission_required(super_agi, monkeypatch): + assistant_reply = "Test reply" + + mock_tool_requiring_permission = MockTool(name="RestrictedTool", permission_required=True) + mock_tool_requiring_permission.permission_required = True + super_agi.output_parser.parse = MagicMock( + return_value=mock_tool_requiring_permission) + + class MockSession: + def add(self, instance): + pass + + def commit(self): + pass + + monkeypatch.setattr("superagi.agent.super_agi.session", MockSession()) + + result, output = super_agi.check_permission_in_restricted_mode(assistant_reply) + assert result + assert output["result"] == "WAITING_FOR_PERMISSION" diff --git a/tests/agent_permissions/test_handle_wait_for_permission.py b/tests/agent_permissions/test_handle_wait_for_permission.py new file mode 100644 index 000000000..6b752ac7e --- /dev/null +++ b/tests/agent_permissions/test_handle_wait_for_permission.py @@ -0,0 +1,68 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +import superagi.models.agent_execution +from superagi.models.agent_execution_feed import AgentExecutionFeed +from superagi.models.agent_execution_permission import AgentExecutionPermission +from superagi.models.base_model import DBBaseModel as Base +from superagi.jobs.agent_executor import AgentExecutor +# Setup +engine = create_engine("sqlite:///:memory:") +Base.metadata.create_all(engine) +Session = sessionmaker(bind=engine) + +class StubSpawnedAgent: + def __init__(self): + self.handled_responses = [] + + def handle_tool_response(self, response): + self.handled_responses.append(response) + return {"result": response} + +def test_handle_wait_for_permission(): + # Setup an in-memory SQLite database + session = Session() + + # Add testing entities to the session + agent_execution = superagi.models.agent_execution.AgentExecution(status="WAITING_FOR_PERMISSION",) + session.add(agent_execution) + session.flush() + + permission_pending = AgentExecutionPermission(id=1, status="PENDING",agent_execution_id=agent_execution.id,tool_name="test_tool") + permission_approved = AgentExecutionPermission(id=2, status="APPROVED", assistant_reply="Approved",agent_execution_id=agent_execution.id,tool_name="test_tool") + permission_denied = AgentExecutionPermission(id=3, status="DENIED", user_feedback="Nope",agent_execution_id=agent_execution.id,tool_name="test_tool") + session.add_all([permission_pending, permission_approved, permission_denied]) + + spawned_agent = StubSpawnedAgent() # You should create this class as a test stub + session.flush() # Flush to get autogenerated ID + agent_execution_id = agent_execution.id + + # Test the pending case + agent_execution.permission_id = permission_pending.id + with pytest.raises(ValueError): + AgentExecutor().handle_wait_for_permission(agent_execution, spawned_agent, session) + + # Test the approved case + agent_execution.status = "WAITING_FOR_PERMISSION" + agent_execution.permission_id = permission_approved.id + AgentExecutor().handle_wait_for_permission(agent_execution, spawned_agent, session) + + agent_execution_feed = session.query(AgentExecutionFeed).filter( + AgentExecutionFeed.agent_execution_id == agent_execution_id).first() + print(agent_execution_feed) + assert agent_execution_feed is not None + assert agent_execution.status == "RUNNING" + assert agent_execution_feed.feed == "Approved" + + # Test the denied case + agent_execution.status = "WAITING_FOR_PERMISSION" + agent_execution.permission_id = permission_denied.id + AgentExecutor().handle_wait_for_permission(agent_execution, spawned_agent, session) + + agent_execution_feeds = session.query(AgentExecutionFeed).filter( + AgentExecutionFeed.agent_execution_id == agent_execution_id).all() + assert len(agent_execution_feeds) == 2 + assert agent_execution.status == "RUNNING" + assert agent_execution_feeds[-1].feed == "User denied the permission to run the tool test_tool" \ + " and has given the following feedback : Nope" \ No newline at end of file