diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..c13efebe9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,97 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python CI + +on: + push: + branches: [ "main", "dev" ] + pull_request: + branches: [ "main", "dev" ] + +permissions: + contents: read + +jobs: + lint: + + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + + - name: Set up Python 3.9 + uses: actions/setup-python@v3 + with: + python-version: "3.9" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + test: + # eliminate duplicate runs + if: github.event_name == 'push' || (github.event.pull_request.head.repo.fork == (github.event_name == 'pull_request_target')) + + permissions: + # Gives the action the necessary permissions for publishing new + # comments in pull requests. + pull-requests: write + # Gives the action the necessary permissions for pushing data to the + # python-coverage-comment-action branch, and for editing existing + # comments (to avoid publishing multiple comments in the same PR) + contents: write + runs-on: ubuntu-latest + timeout-minutes: 30 + strategy: + matrix: + python-version: ["3.9"] + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + submodules: true + + - name: Configure git user SuperAGI-Bot + run: | + git config --global user.name "SuperAGI-Bot" + git config --global user.email "github-bot@superagi.com" + + - name: Set up Python 3.9 + uses: actions/setup-python@v3 + with: + python-version: "3.9" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Test with pytest + run: | + pytest --cov=superagi --cov-branch --cov-report term-missing --cov-report xml \ + tests/unit_tests -s + env: + CI: true + ENV: DEV + PLAIN_OUTPUT: True + REDIS_URL: "localhost:6379" + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..46fd9ae14 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +repos: + - repo: local + hooks: + - id: pylint + name: pylint + entry: pylint + language: system + types: [python] \ No newline at end of file diff --git a/gui/pages/Content/Agents/ActionConsole.js b/gui/pages/Content/Agents/ActionConsole.js index 4dc621bcf..984018dbe 100644 --- a/gui/pages/Content/Agents/ActionConsole.js +++ b/gui/pages/Content/Agents/ActionConsole.js @@ -1,116 
+1,114 @@ import React, { useState, useEffect } from 'react'; import styles from './Agents.module.css'; -import Image from "next/image"; +import Image from 'next/image'; import { updatePermissions } from '@/pages/api/DashboardService'; +import { formatTime } from '@/utils/utils'; + +function ActionBox({ action, index, denied, reasons, handleDeny, handleSelection, setReasons }) { + const isDenied = denied[index]; + + return ( +
+
+
Tool {action.tool_name} is seeking for Permissions
+ {isDenied && ( +
+
Provide Feedback (Optional)
+ { + const newReasons = [...reasons]; + newReasons[index] = e.target.value; + setReasons(newReasons); + }}/> +
+ )} + {isDenied ? ( +
+ + +
+ ) : ( +
+ + +
+ )} +
+
+
+ schedule-icon +
+
{formatTime(action.created_at)}
+
+
+ ); +} export default function ActionConsole({ actions }) { const [hiddenActions, setHiddenActions] = useState([]); - const [reasons, setReasons] = useState(actions.map(() => '')); - const [localActions, setLocalActions] = useState(actions); const [denied, setDenied] = useState([]); + const [reasons, setReasons] = useState([]); const [localActionIds, setLocalActionIds] = useState([]); useEffect(() => { - const updatedActions = actions.filter( - (action) => !localActionIds.includes(action.id) - ); - - if (updatedActions.length > 0) { - setLocalActions( - localActions.map((localAction) => - updatedActions.find(({ id }) => id === localAction.id) || localAction - ) - ); - - const updatedDenied = updatedActions.map(() => false); - const updatedReasons = updatedActions.map(() => ''); + const updatedActions = actions?.filter((action) => !localActionIds.includes(action.id)); - setDenied((prev) => prev.map((value, index) => updatedDenied[index] || value)); - setReasons((prev) => prev.map((value, index) => updatedReasons[index] || value)); + if (updatedActions && updatedActions.length > 0) { + setLocalActionIds((prevIds) => [...prevIds, ...updatedActions.map(({ id }) => id)]); - setLocalActionIds([...localActionIds, ...updatedActions.map(({ id }) => id)]); + setDenied((prevDenied) => prevDenied.map((value, index) => updatedActions[index] ? false : value)); + setReasons((prevReasons) => prevReasons.map((value, index) => updatedActions[index] ? '' : value)); } }, [actions]); - const handleDeny = index => { - const newDeniedState = [...denied]; - newDeniedState[index] = !newDeniedState[index]; - setDenied(newDeniedState); - }; - - const formatDate = (dateString) => { - const now = new Date(); - const date = new Date(dateString); - const seconds = Math.floor((now - date) / 1000); - const minutes = Math.floor(seconds / 60); - const hours = Math.floor(minutes / 60); - const days = Math.floor(hours / 24); - const weeks = Math.floor(days / 7); - const months = Math.floor(days / 30); - const years = Math.floor(days / 365); - - if (years > 0) return `${years} yr${years === 1 ? '' : 's'}`; - if (months > 0) return `${months} mon${months === 1 ? '' : 's'}`; - if (weeks > 0) return `${weeks} wk${weeks === 1 ? '' : 's'}`; - if (days > 0) return `${days} day${days === 1 ? '' : 's'}`; - if (hours > 0) return `${hours} hr${hours === 1 ? '' : 's'}`; - if (minutes > 0) return `${minutes} min${minutes === 1 ? '' : 's'}`; - - return `${seconds} sec${seconds === 1 ? '' : 's'}`; + const handleDeny = (index) => { + setDenied((prevDenied) => { + const newDeniedState = [...prevDenied]; + newDeniedState[index] = !newDeniedState[index]; + return newDeniedState; + }); }; const handleSelection = (index, status, permissionId) => { - setHiddenActions([...hiddenActions, index]); + setHiddenActions((prevHiddenActions) => [...prevHiddenActions, index]); const data = { status: status, user_feedback: reasons[index], }; - updatePermissions(permissionId, data).then((response) => { - console.log("voila") - }); + updatePermissions(permissionId, data).then((response) => {}); }; return ( - <> - {actions.some(action => action.status === "PENDING") ? (
- {actions.map((action, index) => action.status === "PENDING" && !hiddenActions.includes(index) && ( -
-
-
Tool {action.tool_name} is seeking for Permissions
- {denied[index] && ( -
-
Provide Feedback (Optional)
- {const newReasons = [...reasons];newReasons[index] = e.target.value;setReasons(newReasons);}} placeholder="Enter your input here" className="input_medium" /> -
- )} - {denied[index] ? ( -
- - -
- ) : ( -
- - -
- )} -
-
-
- schedule-icon -
-
{formatDate(action.created_at)}
-
-
- ))} -
): - ( -
- no permissions - No Actions to Display! -
)} - + <> + {actions?.some((action) => action.status === 'PENDING') ? ( +
+ {actions.map((action, index) => { + if (action.status === 'PENDING' && !hiddenActions.includes(index)) { + return (); + } + return null; + })} +
+ ) : ( +
+ no-permissions + No Actions to Display! +
+ )} + ); -} \ No newline at end of file +} diff --git a/gui/pages/Content/Agents/ActivityFeed.js b/gui/pages/Content/Agents/ActivityFeed.js index f0ce5f5b0..f1a4d698d 100644 --- a/gui/pages/Content/Agents/ActivityFeed.js +++ b/gui/pages/Content/Agents/ActivityFeed.js @@ -2,7 +2,7 @@ import React, {useEffect, useRef, useState} from 'react'; import styles from './Agents.module.css'; import {getExecutionFeeds} from "@/pages/api/DashboardService"; import Image from "next/image"; -import {formatTime} from "@/utils/utils"; +import {formatTime, loadingTextEffect} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; export default function ActivityFeed({selectedRunId, selectedView, setFetchedData }) { @@ -13,15 +13,7 @@ export default function ActivityFeed({selectedRunId, selectedView, setFetchedDat const [prevFeedsLength, setPrevFeedsLength] = useState(0); useEffect(() => { - const text = 'Thinking'; - let dots = ''; - - const interval = setInterval(() => { - dots = dots.length < 3 ? dots + '.' : ''; - setLoadingText(`${text}${dots}`); - }, 250); - - return () => clearInterval(interval); + loadingTextEffect('Thinking', setLoadingText, 250); }, []); useEffect(() => { diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 75952212a..55a4560fd 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -685,7 +685,7 @@ export default function AgentCreate({sendAgentData, selectedProjectId, fetchAgen
setPermissionDropdown(!permissionDropdown)} style={{width:'100%'}}> {permission}expand-icon
-
+
{permissionDropdown &&
{permissions.map((permit, index) => (
handlePermissionSelect(index)} style={{padding:'12px 14px',maxWidth:'100%'}}> {permit} diff --git a/gui/pages/Content/Agents/AgentTemplatesList.js b/gui/pages/Content/Agents/AgentTemplatesList.js index 82783a2ec..f82156337 100644 --- a/gui/pages/Content/Agents/AgentTemplatesList.js +++ b/gui/pages/Content/Agents/AgentTemplatesList.js @@ -48,8 +48,8 @@ export default function AgentTemplatesList({sendAgentData, selectedProjectId, fe
- {agentTemplates.length > 0 ?
- {agentTemplates.map((item, index) => ( + {agentTemplates.length > 0 ?
+ {agentTemplates.map((item) => (
handleTemplateClick(item)}>
@@ -61,7 +61,7 @@ export default function AgentTemplatesList({sendAgentData, selectedProjectId, fe
-
+
arrow-outward  Browse templates from marketplace
arrow-outward
diff --git a/gui/pages/Content/Agents/AgentWorkspace.js b/gui/pages/Content/Agents/AgentWorkspace.js index a9ae07dd5..7e8942160 100644 --- a/gui/pages/Content/Agents/AgentWorkspace.js +++ b/gui/pages/Content/Agents/AgentWorkspace.js @@ -14,7 +14,7 @@ import {EventBus} from "@/utils/eventBus"; export default function AgentWorkspace({agentId, selectedView}) { const [leftPanel, setLeftPanel] = useState('activity_feed') - const [rightPanel, setRightPanel] = useState('details') + const [rightPanel, setRightPanel] = useState('') const [history, setHistory] = useState(true) const [selectedRun, setSelectedRun] = useState(null) const [runModal, setRunModal] = useState(false) @@ -30,16 +30,19 @@ export default function AgentWorkspace({agentId, selectedView}) { const addInstruction = () => { setInstructions((prevArray) => [...prevArray, 'new instructions']); }; + const handleInstructionDelete = (index) => { const updatedInstructions = [...instructions]; updatedInstructions.splice(index, 1); setInstructions(updatedInstructions); }; + const handleInstructionChange = (index, newValue) => { const updatedInstructions = [...instructions]; updatedInstructions[index] = newValue; setInstructions(updatedInstructions); }; + const addGoal = () => { setGoals((prevArray) => [...prevArray, 'new goal']); }; @@ -138,6 +141,12 @@ export default function AgentWorkspace({agentId, selectedView}) { fetchExecutions(agentId); }, [agentId]) + useEffect(() => { + if(agentDetails) { + setRightPanel(agentDetails.permission_type.includes('RESTRICTED') ? 'action_console' : 'details'); + } + }, [agentDetails]) + function fetchAgentDetails(agentId) { getAgentDetails(agentId) .then((response) => { @@ -145,7 +154,6 @@ export default function AgentWorkspace({agentId, selectedView}) { setTools(response.data.tools); setGoals(response.data.goal); setInstructions(response.data.instruction); - console.log(response.data) }) .catch((error) => { console.error('Error fetching agent details:', error); @@ -217,11 +225,7 @@ export default function AgentWorkspace({agentId, selectedView}) {
{leftPanel === 'activity_feed' &&
- +
} {leftPanel === 'agent_type' &&
}
diff --git a/gui/pages/Content/Agents/ResourceManager.js b/gui/pages/Content/Agents/ResourceManager.js index b773297ac..4eedfa1cb 100644 --- a/gui/pages/Content/Agents/ResourceManager.js +++ b/gui/pages/Content/Agents/ResourceManager.js @@ -119,10 +119,16 @@ export default function ResourceManager({agentId}) { }; const ResourceList = ({ files }) => ( -
- {files.map((file, index) => ( - - ))} +
+ {files.length <= 0 && channel === 'output' ?
+ no-permissions + No Output files! +
:
+ {files.map((file, index) => ( + + ))} +
+ }
); diff --git a/gui/pages/Content/Agents/TaskQueue.js b/gui/pages/Content/Agents/TaskQueue.js index 20d5b945c..d56ab8213 100644 --- a/gui/pages/Content/Agents/TaskQueue.js +++ b/gui/pages/Content/Agents/TaskQueue.js @@ -25,7 +25,11 @@ export default function TaskQueue({ selectedRunId }) { } return ( -
+ <> + {tasks.pending.length <= 0 && tasks.completed.length <= 0 ?
+ no-permissions + No Tasks found! +
:
{tasks.pending.length > 0 &&
Pending Tasks
} {tasks.pending.map((task, index) => (
@@ -49,6 +53,7 @@ export default function TaskQueue({ selectedRunId }) {
))} -
+
} + ); } diff --git a/gui/pages/Content/Marketplace/Market.js b/gui/pages/Content/Marketplace/Market.js index 105712b30..55666edf1 100644 --- a/gui/pages/Content/Marketplace/Market.js +++ b/gui/pages/Content/Marketplace/Market.js @@ -27,14 +27,17 @@ export default function Market() { useEffect(() => { const handleOpenTemplateDetails = (item) => { - setAgentTemplateData(item) - setItemClicked(true) + setAgentTemplateData(item); + setItemClicked(true); }; + const handleBackClick = ()=>{ - setItemClicked(false) + setItemClicked(false); } + EventBus.on('openTemplateDetails', handleOpenTemplateDetails); EventBus.on('goToMarketplace', handleBackClick); + return () => { EventBus.off('openTemplateDetails', handleOpenTemplateDetails); EventBus.off('goToMarketplace', handleBackClick); diff --git a/gui/pages/Content/Marketplace/Market.module.css b/gui/pages/Content/Marketplace/Market.module.css index 012410e2d..e06755211 100644 --- a/gui/pages/Content/Marketplace/Market.module.css +++ b/gui/pages/Content/Marketplace/Market.module.css @@ -288,8 +288,9 @@ .resources { display: flex; - justify-content: space-between; + justify-content: flex-start; flex-wrap: wrap; + gap: 0.3vw; } .agent_resources { @@ -395,12 +396,6 @@ height: 100%; } -.resources { - display: flex; - justify-content: space-between; - flex-wrap: wrap; -} - .top_heading { font-style: normal; font-weight: 400; @@ -444,15 +439,16 @@ } .vertical_line{ - width: 0px; + width: 0; height: 20px; border: 1px solid rgba(255, 255, 255, 0.1); flex: none; margin-left:8px; } + .topbar_heading{ font-style: normal; - font-weight: 300; + font-weight: 500; font-size: 14px; line-height: 18px; color: #FFFFFF; @@ -481,13 +477,15 @@ color: #888888; margin-left: 4px; } + .marketplace_public_button{ display: flex; justify-content: flex-end; align-items: center; width:100%; - padding-right:24px; + padding-right:8px; } + .marketplace_public_content{ height:92.5vh; width:99vw; @@ -495,6 +493,7 @@ margin-left:8px; border-radius: 8px } + .marketplace_public_container{ height:6.5vh; display:flex; diff --git a/gui/pages/Content/Marketplace/MarketAgent.js b/gui/pages/Content/Marketplace/MarketAgent.js index 8c6189442..38f29bfb1 100644 --- a/gui/pages/Content/Marketplace/MarketAgent.js +++ b/gui/pages/Content/Marketplace/MarketAgent.js @@ -3,6 +3,7 @@ import Image from "next/image"; import styles from './Market.module.css'; import {fetchAgentTemplateList} from "@/pages/api/DashboardService"; import {EventBus} from "@/utils/eventBus"; +import {loadingTextEffect} from "@/utils/utils"; export default function MarketAgent(){ const [agentTemplates, setAgentTemplates] = useState([]) @@ -10,31 +11,22 @@ export default function MarketAgent(){ const [isLoading, setIsLoading] = useState(true) const [loadingText, setLoadingText] = useState("Loading Templates"); - useEffect(() => { - const text = 'Loading Templates'; - let dots = ''; - - const interval = setInterval(() => { - dots = dots.length < 3 ? dots + '.' 
: ''; - setLoadingText(`${text}${dots}`); - }, 500); + useEffect(() => { + loadingTextEffect('Loading Templates', setLoadingText, 500); - return () => clearInterval(interval); - }, []); + if(window.location.href.toLowerCase().includes('marketplace')) { + setShowMarketplace(true) + } - useEffect(() => { - if(window.location.href.toLowerCase().includes('marketplace')) { - setShowMarketplace(true) - } - fetchAgentTemplateList() - .then((response) => { - const data = response.data || []; - setAgentTemplates(data); - setIsLoading(false); - }) - .catch((error) => { - console.error('Error fetching agent templates:', error); - }); + fetchAgentTemplateList() + .then((response) => { + const data = response.data || []; + setAgentTemplates(data); + setIsLoading(false); + }) + .catch((error) => { + console.error('Error fetching agent templates:', error); + }); }, []); function handleTemplateClick(item) { @@ -44,7 +36,7 @@ export default function MarketAgent(){ return (
- {!isLoading ?
+ {!isLoading ?
{agentTemplates.map((item, index) => (
handleTemplateClick(item)}>
diff --git a/gui/pages/Content/Marketplace/MarketplacePublic.js b/gui/pages/Content/Marketplace/MarketplacePublic.js index d45d9e657..b226b4e20 100644 --- a/gui/pages/Content/Marketplace/MarketplacePublic.js +++ b/gui/pages/Content/Marketplace/MarketplacePublic.js @@ -3,31 +3,31 @@ import Image from "next/image"; import styles from './Market.module.css'; import Market from './Market'; -export default function MarketplacePublic() { - function handleSignupClick() { - if (window.location.href.toLowerCase().includes('localhost')) { - window.location.href = '/'; - } - else - window.open(`https://app.superagi.com/`, '_self') +export default function MarketplacePublic({env}) { + const handleSignupClick = () => { + if (env === 'PROD') { + window.open(`https://app.superagi.com/`, '_self'); + } else { + window.location.href = '/'; } + }; - return ( -
-
-
super-agi-logo -
-
 marketplace
-
-
- -
-
-
- -
-
- ); + return ( +
+
+
super-agi-logo +
+
 marketplace
+
+
+ +
+
+
+ +
+
+ ); }; diff --git a/gui/pages/_app.css b/gui/pages/_app.css index dce203026..ae6a2970e 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -210,14 +210,14 @@ input[type="range"]::-moz-range-track { .dropdown_container { width: 150px; height: fit-content; - background: #1B192C; + background: #2E293F; display: flex; flex-direction: column; justify-content: center; position: absolute; border-radius: 8px; - box-shadow: -3px 3px 8px rgba(0, 0, 0, 0.3); - padding: 4px 5px; + box-shadow: -2px 2px 24px rgba(0, 0, 0, 0.4); + padding: 5px; } .dropdown_item { diff --git a/gui/pages/_app.js b/gui/pages/_app.js index 4ef6f27b4..d54ba0c22 100644 --- a/gui/pages/_app.js +++ b/gui/pages/_app.js @@ -10,7 +10,7 @@ import { getOrganisation, getProject, validateAccessToken, checkEnvironment, add import { githubClientId } from "@/pages/api/apiConfig"; import { useRouter } from 'next/router'; import querystring from 'querystring'; -import {refreshUrl} from "@/utils/utils"; +import {refreshUrl, loadingTextEffect} from "@/utils/utils"; import MarketplacePublic from "./Content/Marketplace/MarketplacePublic" export default function App() { @@ -24,18 +24,6 @@ export default function App() { const router = useRouter(); const [showMarketplace, setShowMarketplace] = useState(false); - useEffect(() => { - const text = 'Initializing SuperAGI'; - let dots = ''; - - const interval = setInterval(() => { - dots = dots.length < 3 ? dots + '.' : ''; - setLoadingText(`${text}${dots}`); - }, 500); - - return () => clearInterval(interval); - }, []); - function fetchOrganisation(userId) { getOrganisation(userId) .then((response) => { @@ -47,6 +35,12 @@ export default function App() { } useEffect(() => { + if(window.location.href.toLowerCase().includes('marketplace')) { + setShowMarketplace(true); + } + + loadingTextEffect('Initializing SuperAGI', setLoadingText, 500); + checkEnvironment() .then((response) => { const env = response.data.env; @@ -114,12 +108,6 @@ export default function App() { setApplicationState("AUTHENTICATED"); } }, [selectedProject]); - - useEffect(() => { - if(window.location.href.toLowerCase().includes('marketplace')) { - setShowMarketplace(true) - } - }, []); const handleSelectionEvent = (data) => { setSelectedView(data); @@ -139,7 +127,7 @@ export default function App() { {/* eslint-disable-next-line @next/next/no-page-custom-font */} - {showMarketplace &&
} + {showMarketplace &&
} {applicationState === 'AUTHENTICATED' && !showMarketplace ? (
diff --git a/gui/utils/utils.js b/gui/utils/utils.js index 9cabc3936..86dccb2d5 100644 --- a/gui/utils/utils.js +++ b/gui/utils/utils.js @@ -79,4 +79,16 @@ export const refreshUrl = () => { const urlWithoutToken = window.location.origin + window.location.pathname; window.history.replaceState({}, document.title, urlWithoutToken); -}; \ No newline at end of file +}; + +export const loadingTextEffect = (loadingText, setLoadingText, timer) => { + const text = loadingText; + let dots = ''; + + const interval = setInterval(() => { + dots = dots.length < 3 ? dots + '.' : ''; + setLoadingText(`${text}${dots}`); + }, timer); + + return () => clearInterval(interval) +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 71d7a6dd8..345d4f791 100644 --- a/requirements.txt +++ b/requirements.txt @@ -132,3 +132,7 @@ tiktoken==0.4.0 psycopg2==2.9.6 slack-sdk==3.21.3 pytest==7.3.2 +pylint==2.17.4 +pre-commit==3.3.3 +pytest-cov==4.1.0 +pytest-mock==3.11.1 diff --git a/superagi/config/config.py b/superagi/config/config.py index 54067960f..421592845 100644 --- a/superagi/config/config.py +++ b/superagi/config/config.py @@ -25,17 +25,7 @@ def load_config(cls, config_file: str) -> dict: logger.info("\033[91m\033[1m" + "\nConfig file not found. Enter required keys and values." + "\033[0m\033[0m") - config_data = { - "PINECONE_API_KEY": input("Pinecone API Key: "), - "PINECONE_ENVIRONMENT": input("Pinecone Environment: "), - # "OPENAI_API_KEY": input("OpenAI API Key: "), - "GOOGLE_API_KEY": input("Google API Key: "), - "SEARCH_ENGINE_ID": input("Search Engine ID: "), - "RESOURCES_ROOT_DIR": input( - "Resources Root Directory (default: /tmp/): " - ) - or "/tmp/", - } + config_data = {} with open(config_file, "w") as file: yaml.dump(config_data, file, default_flow_style=False) diff --git a/superagi/helper/resource_helper.py b/superagi/helper/resource_helper.py index 025e2e3b2..0449827f0 100644 --- a/superagi/helper/resource_helper.py +++ b/superagi/helper/resource_helper.py @@ -8,20 +8,19 @@ class ResourceHelper: @staticmethod - def make_written_file_resource(file_name: str, agent_id: int, file, channel): + def make_written_file_resource(file_name: str, agent_id: int, channel: str): """ Function to create a Resource object for a written file. Args: file_name (str): The name of the file. agent_id (int): The ID of the agent. - file (FileStorage): The file. channel (str): The channel of the file. Returns: Resource: The Resource object. 
""" - path = get_config("RESOURCES_OUTPUT_ROOT_DIR") + path = ResourceHelper.get_root_dir() storage_type = get_config("STORAGE_TYPE") file_extension = os.path.splitext(file_name)[1][1:] @@ -32,29 +31,59 @@ def make_written_file_resource(file_name: str, agent_id: int, file, channel): else: file_type = "application/misc" - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name + if agent_id is not None: + final_path = ResourceHelper.get_agent_resource_path(file_name, agent_id) + path = path + str(agent_id) + "/" else: - final_path = os.getcwd() + "/" + file_name - + final_path = ResourceHelper.get_resource_path(file_name) file_size = os.path.getsize(final_path) if storage_type == "S3": file_name_parts = file_name.split('.') - file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '').replace('.', '').replace( - ':', '') + '.' + file_name_parts[1] - if channel == "INPUT": - path = 'input' - else: - path = 'output' - - logger.info(path + "/" + file_name) - resource = Resource(name=file_name, path=path + "/" + file_name, storage_type=storage_type, size=file_size, + file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '') \ + .replace('.', '').replace(':', '') + '.' + file_name_parts[1] + path = 'input/' if (channel == "INPUT") else 'output/' + + logger.info(final_path) + resource = Resource(name=file_name, path=path + file_name, storage_type=storage_type, size=file_size, type=file_type, channel="OUTPUT", agent_id=agent_id) return resource + + @staticmethod + def get_resource_path(file_name: str): + """Get final path of the resource. + + Args: + file_name (str): The name of the file. + """ + return ResourceHelper.get_root_dir() + file_name + + @staticmethod + def get_root_dir(): + """Get root dir of the resource. + """ + root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + + if root_dir is not None: + root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir + root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" + else: + root_dir = os.getcwd() + "/" + return root_dir + + @staticmethod + def get_agent_resource_path(file_name: str, agent_id: int): + """Get agent resource path + + Args: + file_name (str): The name of the file. 
+ """ + root_dir = ResourceHelper.get_root_dir() + if agent_id is not None: + directory = os.path.dirname(root_dir + str(agent_id) + "/") + os.makedirs(directory, exist_ok=True) + root_dir = root_dir + str(agent_id) + "/" + final_path = root_dir + file_name + return final_path diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 636be0aea..8253b89dc 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -17,7 +17,9 @@ from superagi.models.organisation import Organisation from superagi.models.project import Project from superagi.models.tool import Tool +from superagi.resource_manager.manager import ResourceManager from superagi.tools.thinking.tools import ThinkingTool +from superagi.tools.tool_response_query_manager import ToolResponseQueryManager from superagi.vector_store.embedding.openai import OpenAiEmbedding from superagi.vector_store.vector_factory import VectorFactory from superagi.helper.encyption_helper import decrypt_data @@ -164,7 +166,7 @@ def execute_next_action(self, agent_execution_id): print(user_tools) tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id, - model_api_key=model_api_key) + model_api_key=model_api_key, session=session) @@ -205,7 +207,7 @@ def execute_next_action(self, agent_execution_id): # finally: engine.dispose() - def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key): + def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key, session): """ Set the default parameters for the tools. @@ -232,6 +234,12 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key tool.image_llm = OpenAi(model=parsed_config["model"], api_key=model_api_key) if hasattr(tool, 'agent_id'): tool.agent_id = agent_id + if hasattr(tool, 'resource_manager'): + tool.resource_manager = ResourceManager(session=session, agent_id=agent_id) + if hasattr(tool, 'tool_response_manager'): + tool.tool_response_manager = ToolResponseQueryManager(session=session, agent_execution_id=parsed_config[ + "agent_execution_id"]) + new_tools.append(tool) return tools diff --git a/superagi/models/agent_execution_feed.py b/superagi/models/agent_execution_feed.py index 9a9cb94f1..98acf5d26 100644 --- a/superagi/models/agent_execution_feed.py +++ b/superagi/models/agent_execution_feed.py @@ -1,4 +1,5 @@ from sqlalchemy import Column, Integer, Text, String +from sqlalchemy.orm import Session from superagi.models.base_model import DBBaseModel @@ -36,3 +37,16 @@ def __repr__(self): return f"AgentExecutionFeed(id={self.id}, " \ f"agent_execution_id={self.agent_execution_id}, " \ f"feed='{self.feed}', role='{self.role}', extra_info={self.extra_info})" + + @classmethod + def get_last_tool_response(cls, session: Session, agent_execution_id: int, tool_name: str = None): + agent_execution_feeds = session.query(AgentExecutionFeed).filter( + AgentExecutionFeed.agent_execution_id == agent_execution_id, + AgentExecutionFeed.role == "system").order_by(AgentExecutionFeed.created_at.desc()).all() + + for agent_execution_feed in agent_execution_feeds: + if tool_name and not agent_execution_feed.feed.startswith("Tool " + tool_name): + continue + if agent_execution_feed.feed.startswith("Tool"): + return agent_execution_feed.feed + return "" diff --git a/tests/agent/__init__.py b/superagi/resource_manager/__init__.py similarity index 100% rename from tests/agent/__init__.py rename to superagi/resource_manager/__init__.py diff --git 
a/superagi/resource_manager/manager.py b/superagi/resource_manager/manager.py new file mode 100644 index 000000000..565aa6276 --- /dev/null +++ b/superagi/resource_manager/manager.py @@ -0,0 +1,59 @@ +from sqlalchemy.orm import Session + +from superagi.helper.resource_helper import ResourceHelper +from superagi.helper.s3_helper import S3Helper +from superagi.lib.logger import logger +import os + + +class ResourceManager: + def __init__(self, session: Session, agent_id: int = None): + self.session = session + self.agent_id = agent_id + + def write_binary_file(self, file_name: str, data): + if self.agent_id is not None: + final_path = ResourceHelper.get_agent_resource_path(file_name, self.agent_id) + else: + final_path = ResourceHelper.get_resource_path(file_name) + + # if self.agent_id is not None: + # directory = os.path.dirname(final_path + "/" + str(self.agent_id) + "/") + # os.makedirs(directory, exist_ok=True) + try: + with open(final_path, mode="wb") as img: + img.write(data) + img.close() + self.write_to_s3(file_name, final_path) + logger.info(f"Binary {file_name} saved successfully") + return f"Binary {file_name} saved successfully" + except Exception as err: + return f"Error: {err}" + + def write_to_s3(self, file_name, final_path): + with open(final_path, 'rb') as img: + resource = ResourceHelper.make_written_file_resource(file_name=file_name, + agent_id=self.agent_id, channel="OUTPUT") + if resource is not None: + self.session.add(resource) + self.session.commit() + self.session.flush() + if resource.storage_type == "S3": + s3_helper = S3Helper() + s3_helper.upload_file(img, path=resource.path) + + def write_file(self, file_name: str, content): + if self.agent_id is not None: + final_path = ResourceHelper.get_agent_resource_path(file_name, self.agent_id) + else: + final_path = ResourceHelper.get_resource_path(file_name) + + try: + with open(final_path, mode="w") as file: + file.write(content) + file.close() + self.write_to_s3(file_name, final_path) + logger.info(f"{file_name} saved successfully") + return f"{file_name} saved successfully" + except Exception as err: + return f"Error: {err}" diff --git a/superagi/tools/code/tools.py b/superagi/tools/code/tools.py deleted file mode 100644 index ca9b8ee29..000000000 --- a/superagi/tools/code/tools.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Type, Optional, List - -from pydantic import BaseModel, Field - -from superagi.agent.agent_prompt_builder import AgentPromptBuilder -from superagi.llms.base_llm import BaseLlm -from superagi.tools.base_tool import BaseTool -from superagi.lib.logger import logger - - -class CodingSchema(BaseModel): - task_description: str = Field( - ..., - description="Coding task description.", - ) - -class CodingTool(BaseTool): - """ - Used to generate code. - - Attributes: - llm: LLM used for code generation. - name : The name of tool. - description : The description of tool. - args_schema : The args schema. - goals : The goals. - """ - llm: Optional[BaseLlm] = None - name = "CodingTool" - description = ( - "Useful for writing, reviewing, and refactoring code. Can also fix bugs and explain programming concepts." - ) - args_schema: Type[CodingSchema] = CodingSchema - goals: List[str] = [] - - class Config: - arbitrary_types_allowed = True - - - def _execute(self, task_description: str): - """ - Execute the code tool. - - Args: - task_description : The task description. - - Returns: - Generated code or error message. 
- """ - try: - prompt = """You're a top-notch coder, knowing all programming languages, software systems, and architecture. - - Your high level goal is: - {goals} - - Provide no information about who you are and focus on writing code. - Ensure code is bug and error free and explain complex concepts through comments - Respond in well-formatted markdown. Ensure code blocks are used for code sections. - - Write code to accomplish the following: - {task} - """ - prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) - prompt = prompt.replace("{task}", task_description) - messages = [{"role": "system", "content": prompt}] - result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit) - return result["content"] - except Exception as e: - logger.error(e) - return f"Error generating text: {e}" \ No newline at end of file diff --git a/superagi/tools/code/write_code.py b/superagi/tools/code/write_code.py new file mode 100644 index 000000000..d322fc2d1 --- /dev/null +++ b/superagi/tools/code/write_code.py @@ -0,0 +1,143 @@ +import re +from typing import Type, Optional, List + +from pydantic import BaseModel, Field + +from superagi.agent.agent_prompt_builder import AgentPromptBuilder +from superagi.helper.token_counter import TokenCounter +from superagi.lib.logger import logger +from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager +from superagi.tools.base_tool import BaseTool +from superagi.tools.tool_response_query_manager import ToolResponseQueryManager + + +class CodingSchema(BaseModel): + code_description: str = Field( + ..., + description="Description of the coding task", + ) + +class CodingTool(BaseTool): + """ + Used to generate code. + + Attributes: + llm: LLM used for code generation. + name : The name of tool. + description : The description of tool. + args_schema : The args schema. + goals : The goals. + resource_manager: Manages the file resources + """ + llm: Optional[BaseLlm] = None + agent_id: int = None + name = "CodingTool" + description = ( + "You will get instructions for code to write. You will write a very long answer. " + "Make sure that every detail of the architecture is, in the end, implemented as code. " + "Think step by step and reason yourself to the right decisions to make sure we get it right. " + "You will first lay out the names of the core classes, functions, methods that will be necessary, " + "as well as a quick comment on their purpose. Then you will output the content of each file including ALL code." + ) + args_schema: Type[CodingSchema] = CodingSchema + goals: List[str] = [] + resource_manager: Optional[ResourceManager] = None + tool_response_manager: Optional[ToolResponseQueryManager] = None + + class Config: + arbitrary_types_allowed = True + + + def _execute(self, code_description: str) -> str: + """ + Execute the write_code tool. + + Args: + code_description : The coding task description. + code_file_name: The name of the file where the generated codes will be saved. + + Returns: + Generated codes files or error message. + """ + try: + prompt = """You are a super smart developer who practices good Development for writing code according to a specification. + + Your high-level goal is: + {goals} + + Coding task description: + {code_description} + + {spec} + + You will get instructions for code to write. + You need to write a detailed answer. Make sure all parts of the architecture are turned into code. 
+ Think carefully about each step and make good choices to get it right. First, list the main classes, + functions, methods you'll use and a quick comment on their purpose. + + Then you will output the content of each file including ALL code. + Each file must strictly follow a markdown code block format, where the following tokens must be replaced such that + [FILENAME] is the lowercase file name including the file extension, + [LANG] is the markup code block language for the code's language, and [CODE] is the code: + [FILENAME] + ```[LANG] + [CODE] + ``` + + You will start with the "entrypoint" file, then go to the ones that are imported by that file, and so on. + Please note that the code should be fully functional. No placeholders. + + Follow a language and framework appropriate best practice file naming convention. + Make sure that files contain all imports, types etc. Make sure that code in different files are compatible with each other. + Ensure to implement all code, if you are unsure, write a plausible implementation. + Include module dependency or package manager dependency definition file. + Before you finish, double check that all parts of the architecture is present in the files. + """ + prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) + prompt = prompt.replace("{code_description}", code_description) + spec_response = self.tool_response_manager.get_last_response("WriteSpecTool") + if spec_response != "": + prompt = prompt.replace("{spec}", "Use this specs for generating the code:\n" + spec_response) + logger.info(prompt) + messages = [{"role": "system", "content": prompt}] + + total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model()) + token_limit = TokenCounter.token_limit(self.llm.get_model()) + result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100)) + + # Get all filenames and corresponding code blocks + regex = r"(\S+?)\n```\S*\n(.+?)```" + matches = re.finditer(regex, result["content"], re.DOTALL) + + file_names = [] + # Save each file + + for match in matches: + # Get the filename + file_name = re.sub(r'[<>"|?*]', "", match.group(1)) + + # Get the code + code = match.group(2) + + # Ensure file_name is not empty + if not file_name.strip(): + continue + + file_names.append(file_name) + save_result = self.resource_manager.write_file(file_name, code) + if save_result.startswith("Error"): + return save_result + + # Get README contents and save + split_result = result["content"].split("```") + if len(split_result) > 0: + readme = split_result[0] + save_readme_result = self.resource_manager.write_file("README.md", readme) + if save_readme_result.startswith("Error"): + return save_readme_result + + return result["content"] + "\n Codes generated and saved successfully in " + ", ".join(file_names) + except Exception as e: + logger.error(e) + return f"Error generating codes: {e}" diff --git a/superagi/tools/code/write_spec.py b/superagi/tools/code/write_spec.py new file mode 100644 index 000000000..f15c89b15 --- /dev/null +++ b/superagi/tools/code/write_spec.py @@ -0,0 +1,97 @@ +from typing import Type, Optional, List + +from pydantic import BaseModel, Field +from superagi.config.config import get_config +from superagi.agent.agent_prompt_builder import AgentPromptBuilder +import os + +from superagi.helper.token_counter import TokenCounter +from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager +from superagi.tools.base_tool import 
BaseTool +from superagi.lib.logger import logger +from superagi.models.db import connect_db +from superagi.helper.resource_helper import ResourceHelper +from superagi.helper.s3_helper import S3Helper +from sqlalchemy.orm import sessionmaker + + +class WriteSpecSchema(BaseModel): + task_description: str = Field( + ..., + description="Specification task description.", + ) + + spec_file_name: str = Field( + ..., + description="Name of the file to write. Only include the file name. Don't include path." + ) + +class WriteSpecTool(BaseTool): + """ + Used to generate program specification. + + Attributes: + llm: LLM used for specification generation. + name : The name of tool. + description : The description of tool. + args_schema : The args schema. + goals : The goals. + resource_manager: Manages the file resources + """ + llm: Optional[BaseLlm] = None + agent_id: int = None + name = "WriteSpecTool" + description = ( + "A tool to write the spec of a program." + ) + args_schema: Type[WriteSpecSchema] = WriteSpecSchema + goals: List[str] = [] + resource_manager: Optional[ResourceManager] = None + + class Config: + arbitrary_types_allowed = True + + def _execute(self, task_description: str, spec_file_name: str) -> str: + """ + Execute the write_spec tool. + + Args: + task_description : The task description. + spec_file_name: The name of the file where the generated specification will be saved. + + Returns: + Generated specification or error message. + """ + try: + prompt = """You are a super smart developer who has been asked to make a specification for a program. + + Your high-level goal is: + {goals} + + Please keep in mind the following when creating the specification: + 1. Be super explicit about what the program should do, which features it should have, and give details about anything that might be unclear. + 2. Lay out the names of the core classes, functions, methods that will be necessary, as well as a quick comment on their purpose. + 3. List all non-standard dependencies that will have to be used. 
+ + Write a specification for the following task: + {task} + """ + prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) + prompt = prompt.replace("{task}", task_description) + messages = [{"role": "system", "content": prompt}] + + total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model()) + token_limit = TokenCounter.token_limit(self.llm.get_model()) + result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100)) + + # Save the specification to a file + write_result = self.resource_manager.write_file(spec_file_name, result["content"]) + if not write_result.startswith("Error"): + return result["content"] + "Specification generated and saved successfully" + else: + return write_result + + except Exception as e: + logger.error(e) + return f"Error generating specification: {e}" \ No newline at end of file diff --git a/superagi/tools/code/write_test.py b/superagi/tools/code/write_test.py new file mode 100644 index 000000000..0060b2cb3 --- /dev/null +++ b/superagi/tools/code/write_test.py @@ -0,0 +1,110 @@ +import re +from typing import Type, Optional, List + +from pydantic import BaseModel, Field + +from superagi.agent.agent_prompt_builder import AgentPromptBuilder +from superagi.helper.token_counter import TokenCounter +from superagi.lib.logger import logger +from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager +from superagi.tools.base_tool import BaseTool +from superagi.tools.tool_response_query_manager import ToolResponseQueryManager + + +class WriteTestSchema(BaseModel): + test_description: str = Field( + ..., + description="Description of the testing task", + ) + test_file_name: str = Field( + ..., + description="Name of the file to write. Only include the file name. Don't include path." + ) + + +class WriteTestTool(BaseTool): + """ + Used to generate pytest unit tests based on the specification. + + Attributes: + llm: LLM used for test generation. + name : The name of tool. + description : The description of tool. + args_schema : The args schema. + goals : The goals. + resource_manager: Manages the file resources + """ + llm: Optional[BaseLlm] = None + agent_id: int = None + name = "WriteTestTool" + description = ( + "You are a super smart developer using Test Driven Development to write tests according to a specification.\n" + "Please generate tests based on the above specification. The tests should be as simple as possible, " + "but still cover all the functionality.\n" + "Write it in the file" + ) + args_schema: Type[WriteTestSchema] = WriteTestSchema + goals: List[str] = [] + resource_manager: Optional[ResourceManager] = None + tool_response_manager: Optional[ToolResponseQueryManager] = None + + + class Config: + arbitrary_types_allowed = True + + def _execute(self, test_description: str, test_file_name: str) -> str: + """ + Execute the write_test tool. + + Args: + test_description : The specification description. + test_file_name: The name of the file where the generated tests will be saved. + + Returns: + Generated pytest unit tests or error message. + """ + try: + prompt = """You are a super smart developer who practices Test Driven Development for writing tests according to a specification. + + Your high-level goal is: + {goals} + + Test Description: + {test_description} + + {spec} + + The tests should be as simple as possible, but still cover all the functionality described in the specification. 
+ """ + prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) + prompt = prompt.replace("{test_description}", test_description) + + spec_response = self.tool_response_manager.get_last_response("WriteSpecTool") + if spec_response != "": + prompt = prompt.replace("{spec}", "Please generate unit tests based on the following specification description:\n" + spec_response) + + messages = [{"role": "system", "content": prompt}] + logger.info(prompt) + + total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model()) + token_limit = TokenCounter.token_limit(self.llm.get_model()) + result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100)) + + # Extract the code part using regular expression + code = re.search(r'(?<=```).*?(?=```)', result["content"], re.DOTALL) + if code: + code_content = code.group(0).strip() + else: + return "Unable to extract code from the response" + + # Save the tests to a file + save_result = self.resource_manager.write_file(test_file_name, code_content) + if not save_result.startswith("Error"): + return result["content"] + " \n Tests generated and saved successfully in " + test_file_name + else: + return save_result + + except Exception as e: + logger.error(e) + return f"Error generating tests: {e}" \ No newline at end of file diff --git a/tests/agent_permissions/__init__.py b/superagi/tools/email/__init__.py similarity index 100% rename from tests/agent_permissions/__init__.py rename to superagi/tools/email/__init__.py diff --git a/superagi/tools/email/send_email.py b/superagi/tools/email/send_email.py index 2d6641cf4..d475e2a71 100644 --- a/superagi/tools/email/send_email.py +++ b/superagi/tools/email/send_email.py @@ -57,7 +57,7 @@ def _execute(self, to: str, subject: str, body: str) -> str: body += f"\n{signature}" message.set_content(body) draft_folder = get_config('EMAIL_DRAFT_MODE_WITH_FOLDER') - send_to_draft = draft_folder is not None or draft_folder != "YOUR_DRAFTS_FOLDER" + send_to_draft = draft_folder is not None and draft_folder != "YOUR_DRAFTS_FOLDER" if message["To"] == "example@example.com" or send_to_draft: conn = ImapEmail().imap_open(draft_folder, email_sender, email_password) conn.append( diff --git a/superagi/tools/file/append_file.py b/superagi/tools/file/append_file.py index b2f62d30a..72f3fae5f 100644 --- a/superagi/tools/file/append_file.py +++ b/superagi/tools/file/append_file.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field from superagi.config.config import get_config +from superagi.helper.resource_helper import ResourceHelper from superagi.tools.base_tool import BaseTool @@ -38,13 +39,7 @@ def _execute(self, file_name: str, content: str): Returns: file written to successfully. or error message. 
""" - final_path = file_name - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name + final_path = ResourceHelper.get_resource_path(file_name) try: directory = os.path.dirname(final_path) os.makedirs(directory, exist_ok=True) diff --git a/superagi/tools/file/delete_file.py b/superagi/tools/file/delete_file.py index cba875cb5..3917f0a1a 100644 --- a/superagi/tools/file/delete_file.py +++ b/superagi/tools/file/delete_file.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field +from superagi.helper.resource_helper import ResourceHelper from superagi.tools.base_tool import BaseTool from superagi.config.config import get_config @@ -36,13 +37,7 @@ def _execute(self, file_name: str, content: str): Returns: file deleted successfully. or error message. """ - final_path = file_name - root_dir = get_config('RESOURCES_INPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name + final_path = ResourceHelper.get_resource_path(file_name) try: os.remove(final_path) return "File deleted successfully." diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py index d11247e06..f425a3186 100644 --- a/superagi/tools/file/write_file.py +++ b/superagi/tools/file/write_file.py @@ -1,16 +1,12 @@ -import os -from typing import Type +from typing import Type, Optional + from pydantic import BaseModel, Field + +from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool -from superagi.config.config import get_config -from sqlalchemy.orm import sessionmaker -from superagi.models.db import connect_db -from superagi.helper.resource_helper import ResourceHelper -# from superagi.helper.s3_helper import upload_to_s3 -from superagi.helper.s3_helper import S3Helper -from superagi.lib.logger import logger +# from superagi.helper.s3_helper import upload_to_s3 class WriteFileInput(BaseModel): @@ -32,6 +28,10 @@ class WriteFileTool(BaseTool): args_schema: Type[BaseModel] = WriteFileInput description: str = "Writes text to a file" agent_id: int = None + resource_manager: Optional[ResourceManager] = None + + class Config: + arbitrary_types_allowed = True def _execute(self, file_name: str, content: str): """ @@ -44,35 +44,5 @@ def _execute(self, file_name: str, content: str): Returns: file written to successfully. or error message. 
""" - engine = connect_db() - Session = sessionmaker(bind=engine) - session = Session() - - final_path = file_name - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + file_name - else: - final_path = os.getcwd() + "/" + file_name + self.resource_manager.write_file(file_name, content) - try: - with open(final_path, 'w', encoding="utf-8") as file: - file.write(content) - file.close() - with open(final_path, 'rb') as file: - resource = ResourceHelper.make_written_file_resource(file_name=file_name, - agent_id=self.agent_id,file=file,channel="OUTPUT") - if resource is not None: - session.add(resource) - session.commit() - session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(file, path=resource.path) - logger.info("Resource Uploaded to S3!") - session.close() - return f"File written to successfully - {file_name}" - except Exception as err: - return f"Error: {err}" diff --git a/superagi/tools/image_generation/README.STABLE_DIFFUSION.md b/superagi/tools/image_generation/README.STABLE_DIFFUSION.md new file mode 100644 index 000000000..f4181544b --- /dev/null +++ b/superagi/tools/image_generation/README.STABLE_DIFFUSION.md @@ -0,0 +1,54 @@ +

+## SuperAGI Stable Diffusion Toolkit
+
+Introducing Stable Diffusion integration with SuperAGI.
+
+You can now use SuperAGI to summon Stable Diffusion to create true-to-life images, opening up a whole new range of possibilities.
+
+# ⚙️ Installation
+
+## 🛠️ Setting up SuperAGI
+
+Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD).
+
+## 🔧 Configuring the API from DreamStudio
+
+You can get the API key needed for Stable Diffusion from DreamStudio by following the instructions below:
+
+1. Create an account or log in with [DreamStudio.ai](http://DreamStudio.ai).
+
+![SD_1](README/SD_1.jpg)
+
+2. Click on the profile icon at the top right to open the settings page, where you can find your API keys.
+
+![SD_2](README/SD_2.jpg)
+
+3. Copy the API key and save it in a separate file.
+
+## 🛠️ Configuring Stable Diffusion with SuperAGI
+
+You can configure SuperAGI with Stable Diffusion using the following steps:
+
+1. Navigate to the **"Toolkit"** page in SuperAGI's dashboard and select **"Image Generation Toolkit"**.
+
+![SD_3](README/SD_3.jpg)
+
+2. Once you've selected the Image Generation Toolkit, a page opens asking for the API key and the model engine. Enter the API key generated from DreamStudio here.
+
+![SD_4](README/SD_4.jpg)
+
+3. If you would like finer control over which Stable Diffusion model is used, you can choose between the following engine IDs:
+
+- 'stable-diffusion-v1'
+- 'stable-diffusion-v1-5'
+- 'stable-diffusion-512-v2-0'
+- 'stable-diffusion-768-v2-0'
+- 'stable-diffusion-512-v2-1'
+- 'stable-diffusion-768-v2-1'
+- 'stable-diffusion-xl-beta-v2-2-2'
+
+You have now successfully configured Stable Diffusion with SuperAGI!
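Under the hood, the toolkit resolves this key through SuperAGI's configuration helper: the `stable_diffusion_image_gen.py` change in this PR calls `get_config("STABILITY_API_KEY")` and checks that it is set. A minimal sketch of that lookup, assuming only what the diff shows (the standalone helper name and the raised exception are illustrative, not part of the PR):

```python
# Illustrative sketch only: mirrors how StableDiffusionImageGenTool._execute
# resolves its API key in this PR; the helper name below is hypothetical.
from superagi.config.config import get_config

def require_stability_api_key() -> str:
    api_key = get_config("STABILITY_API_KEY")
    if api_key is None:
        # The PR only shows the None check; raising here is just for the sketch.
        raise ValueError("STABILITY_API_KEY is not configured")
    return api_key
```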
\ No newline at end of file diff --git a/superagi/tools/image_generation/README/SD_1.jpg b/superagi/tools/image_generation/README/SD_1.jpg new file mode 100644 index 000000000..becead11b Binary files /dev/null and b/superagi/tools/image_generation/README/SD_1.jpg differ diff --git a/superagi/tools/image_generation/README/SD_2.jpg b/superagi/tools/image_generation/README/SD_2.jpg new file mode 100644 index 000000000..01378871a Binary files /dev/null and b/superagi/tools/image_generation/README/SD_2.jpg differ diff --git a/superagi/tools/image_generation/README/SD_3.jpg b/superagi/tools/image_generation/README/SD_3.jpg new file mode 100644 index 000000000..8e1a6450e Binary files /dev/null and b/superagi/tools/image_generation/README/SD_3.jpg differ diff --git a/superagi/tools/image_generation/README/SD_4.jpg b/superagi/tools/image_generation/README/SD_4.jpg new file mode 100644 index 000000000..67a3e1c9b Binary files /dev/null and b/superagi/tools/image_generation/README/SD_4.jpg differ diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py index cd80847f1..2b120efc2 100644 --- a/superagi/tools/image_generation/dalle_image_gen.py +++ b/superagi/tools/image_generation/dalle_image_gen.py @@ -1,44 +1,42 @@ from typing import Type, Optional + +import requests from pydantic import BaseModel, Field + from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager from superagi.tools.base_tool import BaseTool -from superagi.config.config import get_config -import os -import requests -from superagi.models.db import connect_db -from superagi.helper.resource_helper import ResourceHelper -from superagi.helper.s3_helper import S3Helper -from sqlalchemy.orm import sessionmaker -from superagi.lib.logger import logger - - -class ImageGenInput(BaseModel): +class DalleImageGenInput(BaseModel): prompt: str = Field(..., description="Prompt for Image Generation to be used by Dalle.") size: int = Field(..., description="Size of the image to be Generated. default size is 512") num: int = Field(..., description="Number of Images to be generated. default num is 2") - image_name: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") + image_names: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") -class ImageGenTool(BaseTool): +class DalleImageGenTool(BaseTool): """ Dalle Image Generation tool Attributes: - name : The name. - description : The description. - args_schema : The args schema. + name : Name of the tool + description : The description + args_schema : The args schema + llm : The llm + agent_id : The agent id + resource_manager : Manages the file resources """ - name: str = "Dalle Image Generation" - args_schema: Type[BaseModel] = ImageGenInput + name: str = "DalleImageGeneration" + args_schema: Type[BaseModel] = DalleImageGenInput description: str = "Generate Images using Dalle" llm: Optional[BaseLlm] = None agent_id: int = None + resource_manager: Optional[ResourceManager] = None class Config: arbitrary_types_allowed = True - def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2): + def _execute(self, prompt: str, image_names: list, size: int = 512, num: int = 2): """ Execute the Dalle Image Generation tool. 
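For context on the image-generation refactor below: these tools no longer open their own database session or upload to S3 themselves; the agent executor injects a shared `ResourceManager` into any tool that declares a `resource_manager` attribute, as added to `set_default_params_tools` in `agent_executor.py` earlier in this diff. A minimal sketch of that wiring (the standalone function name is illustrative, not part of the PR):

```python
# Illustrative sketch of the injection pattern added in agent_executor.py:
# tools exposing a `resource_manager` attribute receive a session-scoped manager.
from superagi.resource_manager.manager import ResourceManager

def attach_resource_manager(tools, session, agent_id):
    # Hypothetical helper name; the PR does this inline in set_default_params_tools.
    for tool in tools:
        if hasattr(tool, 'resource_manager'):
            tool.resource_manager = ResourceManager(session=session, agent_id=agent_id)
    return tools
```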
@@ -46,47 +44,17 @@ def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2) prompt : The prompt for image generation. size : The size of the image to be generated. num : The number of images to be generated. - image_name (list): The name of the image to be generated. + image_names (list): The name of the image to be generated. Returns: Image generated successfully. or error message. """ - engine = connect_db() - Session = sessionmaker(bind=engine) - session = Session() if size not in [256, 512, 1024]: size = min([256, 512, 1024], key=lambda x: abs(x - size)) response = self.llm.generate_image(prompt, size, num) response = response.__dict__ response = response['_previous']['data'] for i in range(num): - image = image_name[i] - final_path = image - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + image - else: - final_path = os.getcwd() + "/" + image - url = response[i]['url'] - data = requests.get(url).content - try: - with open(final_path, mode="wb") as img: - img.write(data) - with open(final_path, 'rb') as img: - resource = ResourceHelper.make_written_file_resource(file_name=image_name[i], - agent_id=self.agent_id, file=img,channel="OUTPUT") - if resource is not None: - session.add(resource) - session.commit() - session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(img, path=resource.path) - logger.info(f"Image {image} saved successfully") - except Exception as err: - session.close() - return f"Error: {err}" - session.close() + data = requests.get(response[i]['url']).content + self.resource_manager.write_binary_file(image_names[i], data) return "Images downloaded successfully" diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py index be16014ad..3f615d650 100644 --- a/superagi/tools/image_generation/stable_diffusion_image_gen.py +++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py @@ -1,17 +1,13 @@ +import base64 +from io import BytesIO from typing import Type, Optional + +import requests +from PIL import Image from pydantic import BaseModel, Field -from superagi.tools.base_tool import BaseTool from superagi.config.config import get_config -import os -from PIL import Image -from io import BytesIO -import requests -import base64 -from superagi.models.db import connect_db -from superagi.helper.resource_helper import ResourceHelper -from superagi.helper.s3_helper import S3Helper -from sqlalchemy.orm import sessionmaker -from superagi.lib.logger import logger +from superagi.resource_manager.manager import ResourceManager +from superagi.tools.base_tool import BaseTool class StableDiffusionImageGenInput(BaseModel): @@ -20,22 +16,32 @@ class StableDiffusionImageGenInput(BaseModel): width: int = Field(..., description="Width of the image to be Generated. default width is 512") num: int = Field(..., description="Number of Images to be generated. default num is 2") steps: int = Field(..., description="Number of diffusion steps to run. default steps are 50") - image_name: list = Field(..., - description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.") + image_names: list = Field(..., + description="Image Names for the generated images, example 'image_1.png'. 
Only include the image name. Don't include path.") class StableDiffusionImageGenTool(BaseTool): + """ + Stable diffusion Image Generation tool + + Attributes: + name : Name of the tool + description : The description + args_schema : The args schema + agent_id : The agent id + resource_manager : Manages the file resources + """ name: str = "Stable Diffusion Image Generation" args_schema: Type[BaseModel] = StableDiffusionImageGenInput description: str = "Generate Images using Stable Diffusion" agent_id: int = None + resource_manager: Optional[ResourceManager] = None - def _execute(self, prompt: str, image_name: list, width: int = 512, height: int = 512, num: int = 2, - steps: int = 50): - engine = connect_db() - Session = sessionmaker(bind=engine) - session = Session() + class Config: + arbitrary_types_allowed = True + def _execute(self, prompt: str, image_names: list, width: int = 512, height: int = 512, num: int = 2, + steps: int = 50): api_key = get_config("STABILITY_API_KEY") if api_key is None: @@ -58,21 +64,11 @@ def _execute(self, prompt: str, image_name: list, width: int = 512, height: int img_data = base64.b64decode(image_base64) final_img = Image.open(BytesIO(img_data)) image_format = final_img.format + img_byte_arr = BytesIO() + final_img.save(img_byte_arr, format=image_format) - image = image_name[i] - root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') - - final_path = self.build_file_path(image, root_dir) + self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue()) - try: - self.upload_to_s3(final_img, final_path, image_format, image_name[i], session) - - logger.info(f"Image {image} saved successfully") - except Exception as err: - session.close() - print(f"Error in _execute: {err}") - return f"Error: {err}" - session.close() return "Images downloaded and saved successfully" def call_stable_diffusion(self, api_key, width, height, num, prompt, steps): @@ -90,11 +86,7 @@ def call_stable_diffusion(self, api_key, width, height, num, prompt, steps): "Authorization": f"Bearer {api_key}" }, json={ - "text_prompts": [ - { - "text": prompt - } - ], + "text_prompts": [{"text": prompt}], "height": height, "width": width, "samples": num, @@ -102,27 +94,3 @@ def call_stable_diffusion(self, api_key, width, height, num, prompt, steps): }, ) return response - - def upload_to_s3(self, final_img, final_path, image_format, file_name, session): - with open(final_path, mode="wb") as img: - final_img.save(img, format=image_format) - with open(final_path, 'rb') as img: - resource = ResourceHelper.make_written_file_resource(file_name=file_name, - agent_id=self.agent_id, file=img, channel="OUTPUT") - logger.info(resource) - if resource is not None: - session.add(resource) - session.commit() - session.flush() - if resource.storage_type == "S3": - s3_helper = S3Helper() - s3_helper.upload_file(img, path=resource.path) - - def build_file_path(self, image, root_dir): - if root_dir is not None: - root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir - root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" - final_path = root_dir + image - else: - final_path = os.getcwd() + "/" + image - return final_path diff --git a/superagi/tools/slack/README.md b/superagi/tools/slack/README.md new file mode 100644 index 000000000..bfb5bb4de --- /dev/null +++ b/superagi/tools/slack/README.md @@ -0,0 +1,57 @@ +

+ +

+ +# SuperAGI Slack Toolkit + +This SuperAGI toolkit lets users send messages to Slack channels and provides a strong foundation for use cases to come. + +**Features:** + +1. Send Message - This tool gives SuperAGI the ability to send messages to the Slack channels that you have specified. + +## 🛠️ Installation + +Setting up SuperAGI: + +Set up SuperAGI by following the instructions given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD). + +### 🔧 **Slack Configuration:** + +1. Create an application on the Slack API portal. + + ![Slack_1](README/Slack_1.jpg) + +2. Select "From scratch". + + ![Slack_2](README/Slack_2.jpg) + +3. Add your application's name and select the workspace in which you'd like to use your Slack application. + + ![Slack_3](README/Slack_3.jpg) + +4. Once the app creation process is done, head to the "OAuth and Permissions" tab. + + ![Slack_4](README/Slack_4.jpg) + +5. Find the **Bot Token Scopes** section, add the **"chat:write"** scope, and save it. + + ![Slack_5](README/Slack_5.jpg) + +6. Once you've defined the scope, install the application to your workspace. + + ![Slack_6](README/Slack_6.jpg) + +7. After installation, you will receive the bot token. + + ![Slack_7](README/Slack_7.jpg) + +8. Once the installation is done, you'll get the Bot User OAuth Token, which needs to be added to config.yaml as the value of the **slack_bot_token** variable. + +![Slack_8](README/Slack_8.jpg) + +Once the configuration is complete, you can install the app in the channel of your choice and create an agent on SuperAGI that can now send messages to the Slack channel! \ No newline at end of file diff --git a/superagi/tools/slack/README/Slack_1.jpg b/superagi/tools/slack/README/Slack_1.jpg new file mode 100644 index 000000000..8d37a189a Binary files /dev/null and b/superagi/tools/slack/README/Slack_1.jpg differ diff --git a/superagi/tools/slack/README/Slack_2.jpg b/superagi/tools/slack/README/Slack_2.jpg new file mode 100644 index 000000000..e9f4d392f Binary files /dev/null and b/superagi/tools/slack/README/Slack_2.jpg differ diff --git a/superagi/tools/slack/README/Slack_3.jpg b/superagi/tools/slack/README/Slack_3.jpg new file mode 100644 index 000000000..23b057a9e Binary files /dev/null and b/superagi/tools/slack/README/Slack_3.jpg differ diff --git a/superagi/tools/slack/README/Slack_4.jpg b/superagi/tools/slack/README/Slack_4.jpg new file mode 100644 index 000000000..35694a79d Binary files /dev/null and b/superagi/tools/slack/README/Slack_4.jpg differ diff --git a/superagi/tools/slack/README/Slack_5.jpg b/superagi/tools/slack/README/Slack_5.jpg new file mode 100644 index 000000000..d1b1acd96 Binary files /dev/null and b/superagi/tools/slack/README/Slack_5.jpg differ diff --git a/superagi/tools/slack/README/Slack_6.jpg b/superagi/tools/slack/README/Slack_6.jpg new file mode 100644 index 000000000..2b361b479 Binary files /dev/null and b/superagi/tools/slack/README/Slack_6.jpg differ diff --git a/superagi/tools/slack/README/Slack_7.jpg b/superagi/tools/slack/README/Slack_7.jpg new file mode 100644 index 000000000..c9305134c Binary files /dev/null and b/superagi/tools/slack/README/Slack_7.jpg differ diff --git a/superagi/tools/slack/README/Slack_8.jpg b/superagi/tools/slack/README/Slack_8.jpg new file mode 100644 index 000000000..3016938df Binary files /dev/null and b/superagi/tools/slack/README/Slack_8.jpg differ diff --git a/superagi/tools/thinking/tools.py b/superagi/tools/thinking/tools.py index 50c4c699a..302b3374d 100644 ---
a/superagi/tools/thinking/tools.py +++ b/superagi/tools/thinking/tools.py @@ -1,15 +1,12 @@ -import os -import openai from typing import Type, Optional, List from pydantic import BaseModel, Field from superagi.agent.agent_prompt_builder import AgentPromptBuilder -from superagi.tools.base_tool import BaseTool -from superagi.config.config import get_config -from superagi.llms.base_llm import BaseLlm -from pydantic import BaseModel, Field, PrivateAttr from superagi.lib.logger import logger +from superagi.llms.base_llm import BaseLlm +from superagi.tools.base_tool import BaseTool +from superagi.tools.tool_response_query_manager import ToolResponseQueryManager class ThinkingSchema(BaseModel): @@ -36,6 +33,7 @@ class ThinkingTool(BaseTool): args_schema: Type[ThinkingSchema] = ThinkingSchema goals: List[str] = [] permission_required: bool = False + tool_response_manager: Optional[ToolResponseQueryManager] = None class Config: arbitrary_types_allowed = True @@ -58,13 +56,17 @@ def _execute(self, task_description: str): and the following task, `{task_description}`. + Below is last tool response: + `{last_tool_response}` + Perform the task by understanding the problem, extracting variables, and being smart and efficient. Provide a descriptive response, make decisions yourself when confronted with choices and provide reasoning for ideas / decisions. """ prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) prompt = prompt.replace("{task_description}", task_description) - + last_tool_response = self.tool_response_manager.get_last_response() + prompt = prompt.replace("{last_tool_response}", last_tool_response) messages = [{"role": "system", "content": prompt}] result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit) return result["content"] diff --git a/superagi/tools/tool_response_query_manager.py b/superagi/tools/tool_response_query_manager.py new file mode 100644 index 000000000..5a2387648 --- /dev/null +++ b/superagi/tools/tool_response_query_manager.py @@ -0,0 +1,12 @@ +from sqlalchemy.orm import Session + +from superagi.models.agent_execution_feed import AgentExecutionFeed + + +class ToolResponseQueryManager: + def __init__(self, session: Session, agent_execution_id: int): + self.session = session + self.agent_execution_id = agent_execution_id + + def get_last_response(self, tool_name: str = None): + return AgentExecutionFeed.get_last_tool_response(self.session, self.agent_execution_id, tool_name) diff --git a/superagi/worker.py b/superagi/worker.py index c107da67e..243eda64d 100644 --- a/superagi/worker.py +++ b/superagi/worker.py @@ -4,7 +4,7 @@ from celery import Celery from superagi.config.config import get_config -redis_url = get_config('REDIS_URL') +redis_url = get_config('REDIS_URL') or 'localhost:6379' app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"]) app.conf.broker_url = "redis://" + redis_url + "/0" diff --git a/test.py b/test.py index f8591dc46..eb466739b 100644 --- a/test.py +++ b/test.py @@ -38,23 +38,23 @@ def ask_user_for_goals(): return goals - -def run_superagi_cli(agent_name=None,agent_description=None,agent_goals=None): +def run_superagi_cli(agent_name=None, agent_description=None, agent_goals=None): # Create default organization organization = Organisation(name='Default Organization', description='Default organization description') session.add(organization) session.flush() # Flush pending changes to generate the agent's ID session.commit() logger.info(organization) - + # Create default 
project associated with the organization - project = Project(name='Default Project', description='Default project description', organisation_id=organization.id) + project = Project(name='Default Project', description='Default project description', + organisation_id=organization.id) session.add(project) session.flush() # Flush pending changes to generate the agent's ID session.commit() logger.info(project) - #Agent + # Agent if agent_name is None: agent_name = input("Enter agent name: ") if agent_description is None: @@ -65,24 +65,24 @@ def run_superagi_cli(agent_name=None,agent_description=None,agent_goals=None): session.commit() logger.info(agent) - #Agent Config + # Agent Config # Create Agent Configuration agent_config_values = { "goal": ask_user_for_goals() if agent_goals is None else agent_goals, "agent_type": "Type Non-Queue", - "constraints": [ "~4000 word limit for short term memory. ", - "Your short term memory is short, so immediately save important information to files.", - "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.", - "No user assistance", - "Exclusively use the commands listed in double quotes e.g. \"command name\"" - ], + "constraints": ["~4000 word limit for short term memory. ", + "Your short term memory is short, so immediately save important information to files.", + "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.", + "No user assistance", + "Exclusively use the commands listed in double quotes e.g. \"command name\"" + ], "tools": [], "exit": "Default", "iteration_interval": 0, "model": "gpt-4", "permission_type": "Default", "LTM_DB": "Pinecone", - "memory_window":10 + "memory_window": 10 } # print("Id is ") @@ -106,5 +106,6 @@ def run_superagi_cli(agent_name=None,agent_description=None,agent_goals=None): logger.info(execution) execute_agent.delay(execution.id, datetime.now()) - -run_superagi_cli(agent_name=agent_name,agent_description=agent_description,agent_goals=agent_goals) \ No newline at end of file + + +run_superagi_cli(agent_name=agent_name, agent_description=agent_description, agent_goals=agent_goals) diff --git a/tests/helper/__init__.py b/tests/integration_tests/__init__.py similarity index 100% rename from tests/helper/__init__.py rename to tests/integration_tests/__init__.py diff --git a/tests/integration_tests/vector_store/__init__.py b/tests/integration_tests/vector_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/vector_store/test_weaviate.py b/tests/integration_tests/vector_store/test_weaviate.py similarity index 100% rename from tests/vector_store/test_weaviate.py rename to tests/integration_tests/vector_store/test_weaviate.py diff --git a/tests/tools/email/__init__.py b/tests/tools/email/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/email/test_send_email.py b/tests/tools/email/test_send_email.py new file mode 100644 index 000000000..16d802477 --- /dev/null +++ b/tests/tools/email/test_send_email.py @@ -0,0 +1,70 @@ +from unittest.mock import MagicMock + +import pytest +import imaplib +import time +from email.message import EmailMessage + +from superagi.config.config import get_config +from superagi.helper.imap_email import ImapEmail +from superagi.tools.email import send_email +from superagi.tools.email.send_email import SendEmailTool + +def test_send_to_draft(mocker): + + mock_get_config = 
mocker.patch('superagi.tools.email.send_email.get_config', autospec=True) + mock_get_config.side_effect = [ + 'test_sender@test.com', # EMAIL_ADDRESS + 'password', # EMAIL_PASSWORD + 'Test Signature', # EMAIL_SIGNATURE + "Draft", # EMAIL_DRAFT_MODE_WITH_FOLDER + 'smtp_host', # EMAIL_SMTP_HOST + 'smtp_port' # EMAIL_SMTP_PORT + ] + + + # Mocking the ImapEmail call + mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail') + mock_imap_instance = mock_imap_email.return_value.imap_open.return_value + + # Mocking the SMTP call + mock_smtp = mocker.patch('smtplib.SMTP') + smtp_instance = mock_smtp.return_value + + # Test the SendEmailTool's execute method + send_email_tool = SendEmailTool() + result = send_email_tool._execute('mukunda@contlo.com', 'Test Subject', 'Test Body') + + # Assert the return value + assert result == 'Email went to Draft' + +def test_send_to_mailbox(mocker): + # Mocking the get_config calls + mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config') + mock_get_config.side_effect = [ + 'test_sender@test.com', # EMAIL_ADDRESS + 'password', # EMAIL_PASSWORD + 'Test Signature', # EMAIL_SIGNATURE + "YOUR_DRAFTS_FOLDER", # EMAIL_DRAFT_MODE_WITH_FOLDER + 'smtp_host', # EMAIL_SMTP_HOST + 'smtp_port' # EMAIL_SMTP_PORT + ] + + # mock_get_config.return_value = 'True' + # Mocking the ImapEmail call + mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail') + mock_imap_instance = mock_imap_email.return_value.imap_open.return_value + + # Mocking the SMTP call + mock_smtp = mocker.patch('smtplib.SMTP') + smtp_instance = mock_smtp.return_value + + # Test the SendEmailTool's execute method + send_email_tool = SendEmailTool() + result = send_email_tool._execute('test_receiver@test.com', 'Test Subject', 'Test Body') + + # Assert that the ImapEmail was not called (no draft mode) + mock_imap_email.assert_not_called() + + # Assert the return value + assert result == 'Email was sent to test_receiver@test.com' \ No newline at end of file diff --git a/tests/tools/image_gen_test.py b/tests/tools/image_gen_test.py deleted file mode 100644 index f96454545..000000000 --- a/tests/tools/image_gen_test.py +++ /dev/null @@ -1,46 +0,0 @@ -import os -import unittest -from unittest.mock import patch, MagicMock - -from superagi.tools.image_generation.dalle_image_gen import ImageGenTool - - -class TestImageGenTool(unittest.TestCase): - - @patch('openai.Image.create') - @patch('requests.get') - @patch('superagi.tools.image_generation.dalle_image_gen.get_config') - def test_image_gen_tool_execute(self, mock_get_config, mock_requests_get, mock_openai_create): - # Setup - tool = ImageGenTool() - prompt = 'Artificial Intelligence' - image_names = ['image1.png', 'image2.png'] - size = 512 - num = 2 - - # Mock responses - mock_get_config.return_value = "/tmp" - mock_openai_create.return_value = MagicMock(_previous=MagicMock(data=[ - {"url": "https://example.com/image1.png"}, - {"url": "https://example.com/image2.png"} - ])) - mock_requests_get.return_value.content = b"image_data" - - # Run the method under test - response = tool._execute(prompt, image_names, size, num) - - # Assert the method ran correctly - self.assertEqual(response, "Images downloaded successfully") - for image_name in image_names: - path = "/tmp/" + image_name - self.assertTrue(os.path.exists(path)) - with open(path, "rb") as file: - self.assertEqual(file.read(), b"image_data") - - # Clean up - for image_name in image_names: - os.remove("/tmp/" + image_name) - - -if __name__ == '__main__': - 
unittest.main() \ No newline at end of file diff --git a/tests/tools/stable_diffusion_image_gen_test.py b/tests/tools/stable_diffusion_image_gen_test.py deleted file mode 100644 index ae2ebc673..000000000 --- a/tests/tools/stable_diffusion_image_gen_test.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -import unittest -from unittest.mock import patch, MagicMock -from PIL import Image -from io import BytesIO -import base64 -from superagi.config.config import get_config - -from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool - - -class TestStableDiffusionImageGenTool(unittest.TestCase): - - @patch('requests.post') - @patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config') - def test_stable_diffusion_image_gen_tool_execute(self, mock_get_config, mock_requests_post): - # Setup - tool = StableDiffusionImageGenTool() - prompt = 'Artificial Intelligence' - image_names = ['image1.png', 'image2.png'] - height = 512 - width = 512 - num = 2 - steps = 50 - - # Create a temporary directory for image storage - temp_dir = get_config("RESOURCES_OUTPUT_ROOT_DIR") - - # Mock responses - mock_configs = {"STABILITY_API_KEY": "api_key", "ENGINE_ID": "engine_id", "RESOURCES_OUTPUT_ROOT_DIR": temp_dir} - mock_get_config.side_effect = lambda k: mock_configs[k] - - # Prepare sample image bytes - img = Image.new("RGB", (width, height), "white") - buffer = BytesIO() - img.save(buffer, "PNG") - buffer.seek(0) - img_data = buffer.getvalue() - encoded_image_data = base64.b64encode(img_data).decode() - - # Use the proper base64-encoded string - mock_requests_post.return_value = MagicMock(status_code=200, json=lambda: { - "artifacts": [ - {"base64": encoded_image_data}, - {"base64": encoded_image_data} - ] - }) - - # Run the method under test - response = tool._execute(prompt, image_names, width, height, num, steps) - self.assertEqual(response, f"Images downloaded successfully") - - for image_name in image_names: - path = os.path.join(temp_dir, image_name) - self.assertTrue(os.path.exists(path)) - with open(path, "rb") as file: - self.assertEqual(file.read(), img_data) - - # Clean up - for image_name in image_names: - os.remove(os.path.join(temp_dir, image_name)) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/agent/__init__.py b/tests/unit_tests/agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/agent/test_task_queue.py b/tests/unit_tests/agent/test_task_queue.py similarity index 98% rename from tests/agent/test_task_queue.py rename to tests/unit_tests/agent/test_task_queue.py index 9627231be..85dd20f31 100644 --- a/tests/agent/test_task_queue.py +++ b/tests/unit_tests/agent/test_task_queue.py @@ -47,5 +47,6 @@ def test_get_last_task_details(self, mock_get_last_task_details): self.queue.get_last_task_details() mock_get_last_task_details.assert_called() + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/unit_tests/agent_permissions/__init__.py b/tests/unit_tests/agent_permissions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/agent_permissions/test_check_permission_in_restricted_mode.py b/tests/unit_tests/agent_permissions/test_check_permission_in_restricted_mode.py similarity index 100% rename from tests/agent_permissions/test_check_permission_in_restricted_mode.py rename to 
tests/unit_tests/agent_permissions/test_check_permission_in_restricted_mode.py diff --git a/tests/agent_permissions/test_handle_wait_for_permission.py b/tests/unit_tests/agent_permissions/test_handle_wait_for_permission.py similarity index 100% rename from tests/agent_permissions/test_handle_wait_for_permission.py rename to tests/unit_tests/agent_permissions/test_handle_wait_for_permission.py diff --git a/tests/unit_tests/helper/__init__.py b/tests/unit_tests/helper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/helper/test_github_helper.py b/tests/unit_tests/helper/test_github_helper.py similarity index 100% rename from tests/helper/test_github_helper.py rename to tests/unit_tests/helper/test_github_helper.py diff --git a/tests/helper/test_json_cleaner.py b/tests/unit_tests/helper/test_json_cleaner.py similarity index 97% rename from tests/helper/test_json_cleaner.py rename to tests/unit_tests/helper/test_json_cleaner.py index be64eaa6f..8579a9900 100644 --- a/tests/helper/test_json_cleaner.py +++ b/tests/unit_tests/helper/test_json_cleaner.py @@ -40,4 +40,4 @@ def test_clean_newline_spaces_json(): def test_has_newline_in_string(): test_str = r'{key: "value\n"\n \n}' result = JsonCleaner.check_and_clean_json(test_str) - assert result == '{key: "value\\n"}' + assert result == '{key: "value"}' diff --git a/tests/unit_tests/helper/test_resource_helper.py b/tests/unit_tests/helper/test_resource_helper.py new file mode 100644 index 000000000..6c24cecb2 --- /dev/null +++ b/tests/unit_tests/helper/test_resource_helper.py @@ -0,0 +1,39 @@ +import pytest +from unittest.mock import patch +from superagi.helper.resource_helper import ResourceHelper + +def test_make_written_file_resource(mocker): + mocker.patch('os.getcwd', return_value='/') + # mocker.patch('os.getcwd', return_value='/') + mocker.patch('os.makedirs', return_value=None) + mocker.patch('os.path.getsize', return_value=1000) + mocker.patch('os.path.splitext', return_value=("", ".txt")) + mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/', 'local', None]) + + with patch('superagi.helper.resource_helper.logger') as logger_mock: + result = ResourceHelper.make_written_file_resource('test.txt', 1, 'INPUT') + + assert result.name == 'test.txt' + assert result.path == '/1/test.txt' + assert result.storage_type == 'local' + assert result.size == 1000 + assert result.type == 'application/txt' + assert result.channel == 'OUTPUT' + assert result.agent_id == 1 + +def test_get_resource_path(mocker): + mocker.patch('os.getcwd', return_value='/') + mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/']) + + result = ResourceHelper.get_resource_path('test.txt') + + assert result == '/test.txt' + +def test_get_agent_resource_path(mocker): + mocker.patch('os.getcwd', return_value='/') + mocker.patch('os.makedirs') + mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/']) + + result = ResourceHelper.get_agent_resource_path('test.txt', 1) + + assert result == '/1/test.txt' diff --git a/tests/unit_tests/models/__init__.py b/tests/unit_tests/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/models/test_agent_execution_feed.py b/tests/unit_tests/models/test_agent_execution_feed.py new file mode 100644 index 000000000..15924c551 --- /dev/null +++ b/tests/unit_tests/models/test_agent_execution_feed.py @@ -0,0 +1,27 @@ +import pytest +from unittest.mock import Mock, create_autospec +from sqlalchemy.orm import Session 
+from superagi.models.agent_execution_feed import AgentExecutionFeed + + +def test_get_last_tool_response(): + mock_session = create_autospec(Session) + agent_execution_feed_1 = AgentExecutionFeed(id=1, agent_execution_id=2, feed="Tool test1", role='system') + agent_execution_feed_2 = AgentExecutionFeed(id=2, agent_execution_id=2, feed="Tool test2", role='system') + + mock_session.query().filter().order_by().all.return_value = [agent_execution_feed_1, agent_execution_feed_2] + + result = AgentExecutionFeed.get_last_tool_response(mock_session, 2) + + assert result == agent_execution_feed_1.feed # as agent_execution_feed_1 should be the latest based on created_at + + +def test_get_last_tool_response_with_tool_name(): + mock_session = create_autospec(Session) + agent_execution_feed_1 = AgentExecutionFeed(id=1, agent_execution_id=2, feed="Tool test1", role='system') + agent_execution_feed_2 = AgentExecutionFeed(id=2, agent_execution_id=2, feed="Tool test2", role='system') + + mock_session.query().filter().order_by().all.return_value = [agent_execution_feed_1, agent_execution_feed_2] + + result = AgentExecutionFeed.get_last_tool_response(mock_session, 2, "test2") + assert result == agent_execution_feed_2.feed diff --git a/tests/unit_tests/resource_manager/__init__.py b/tests/unit_tests/resource_manager/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/resource_manager/test_resource_manager.py b/tests/unit_tests/resource_manager/test_resource_manager.py new file mode 100644 index 000000000..a4630f0d5 --- /dev/null +++ b/tests/unit_tests/resource_manager/test_resource_manager.py @@ -0,0 +1,37 @@ +import pytest +from unittest.mock import Mock, patch +from superagi.models.resource import Resource +from superagi.helper.resource_helper import ResourceHelper +from superagi.helper.s3_helper import S3Helper +from superagi.lib.logger import logger + +from superagi.resource_manager.manager import ResourceManager + +@pytest.fixture +def resource_manager(): + session_mock = Mock() + resource_manager = ResourceManager(session_mock) + #resource_manager.agent_id = 1 # replace with actual value + return resource_manager + + +def test_write_binary_file(resource_manager): + with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'), \ + patch.object(ResourceHelper, 'make_written_file_resource', + return_value=Resource(name='test.png', storage_type='S3')), \ + patch.object(S3Helper, 'upload_file'), \ + patch.object(logger, 'info') as logger_mock: + result = resource_manager.write_binary_file('test.png', b'data') + assert result == "Binary test.png saved successfully" + logger_mock.assert_called_once_with("Binary test.png saved successfully") + + +def test_write_file(resource_manager): + with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'), \ + patch.object(ResourceHelper, 'make_written_file_resource', + return_value=Resource(name='test.txt', storage_type='S3')), \ + patch.object(S3Helper, 'upload_file'), \ + patch.object(logger, 'info') as logger_mock: + result = resource_manager.write_file('test.txt', 'content') + assert result == "test.txt saved successfully" + logger_mock.assert_called_once_with("test.txt saved successfully") diff --git a/tests/unit_tests/tools/__init__.py b/tests/unit_tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/tools/test_dalle_image_gen.py b/tests/unit_tests/tools/test_dalle_image_gen.py new file mode 100644 index 000000000..82653d5a2 --- 
/dev/null +++ b/tests/unit_tests/tools/test_dalle_image_gen.py @@ -0,0 +1,27 @@ +from unittest.mock import Mock, patch +import pytest +from superagi.tools.image_generation.dalle_image_gen import DalleImageGenTool + + +class MockBaseLlm: + def generate_image(self, prompt, size, num): + return Mock(_previous={"data": [{"url": f"https://example.com/image_{i}.png"} for i in range(num)]}) + + +class TestDalleImageGenTool: + + @pytest.fixture + def tool(self): + tool = DalleImageGenTool() + tool.llm = MockBaseLlm() + response_mock = Mock() + tool.resource_manager = response_mock + return tool + + @patch("requests.get") + def test_execute(self, mock_get, tool): + mock_get.return_value = Mock(content=b"fake image data") + response = tool._execute("test prompt", ["test1.png", "test2.png"], size=512, num=2) + assert response == "Images downloaded successfully" + mock_get.assert_called_with("https://example.com/image_1.png") + assert tool.resource_manager.write_binary_file.call_count == 2 diff --git a/tests/unit_tests/tools/test_send_email.py b/tests/unit_tests/tools/test_send_email.py new file mode 100644 index 000000000..b0d80928e --- /dev/null +++ b/tests/unit_tests/tools/test_send_email.py @@ -0,0 +1,61 @@ +from superagi.tools.email.send_email import SendEmailTool +import pytest + +def test_send_to_draft(mocker): + + mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config', autospec=True) + mock_get_config.side_effect = [ + 'test_sender@test.com', # EMAIL_ADDRESS + 'password', # EMAIL_PASSWORD + 'Test Signature', # EMAIL_SIGNATURE + "Draft", # EMAIL_DRAFT_MODE_WITH_FOLDER + 'smtp_host', # EMAIL_SMTP_HOST + 'smtp_port' # EMAIL_SMTP_PORT + ] + + + # Mocking the ImapEmail call + mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail') + mock_imap_instance = mock_imap_email.return_value.imap_open.return_value + + # Mocking the SMTP call + mock_smtp = mocker.patch('smtplib.SMTP') + smtp_instance = mock_smtp.return_value + + # Test the SendEmailTool's execute method + send_email_tool = SendEmailTool() + result = send_email_tool._execute('mukunda@contlo.com', 'Test Subject', 'Test Body') + + # Assert the return value + assert result == 'Email went to Draft' + +def test_send_to_mailbox(mocker): + # Mocking the get_config calls + mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config') + mock_get_config.side_effect = [ + 'test_sender@test.com', # EMAIL_ADDRESS + 'password', # EMAIL_PASSWORD + 'Test Signature', # EMAIL_SIGNATURE + "YOUR_DRAFTS_FOLDER", # EMAIL_DRAFT_MODE_WITH_FOLDER + 'smtp_host', # EMAIL_SMTP_HOST + 'smtp_port' # EMAIL_SMTP_PORT + ] + + # mock_get_config.return_value = 'True' + # Mocking the ImapEmail call + mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail') + mock_imap_instance = mock_imap_email.return_value.imap_open.return_value + + # Mocking the SMTP call + mock_smtp = mocker.patch('smtplib.SMTP') + smtp_instance = mock_smtp.return_value + + # Test the SendEmailTool's execute method + send_email_tool = SendEmailTool() + result = send_email_tool._execute('test_receiver@test.com', 'Test Subject', 'Test Body') + + # Assert that the ImapEmail was not called (no draft mode) + mock_imap_email.assert_not_called() + + # Assert the return value + assert result == 'Email was sent to test_receiver@test.com' \ No newline at end of file diff --git a/tests/unit_tests/tools/test_stable_diffusion_image_gen.py b/tests/unit_tests/tools/test_stable_diffusion_image_gen.py new file mode 100644 index 000000000..6d2dc75c0 --- 
/dev/null +++ b/tests/unit_tests/tools/test_stable_diffusion_image_gen.py @@ -0,0 +1,51 @@ +import base64 +from io import BytesIO +from unittest.mock import patch, Mock + +import pytest +from PIL import Image + +from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool + + +def create_sample_image_base64(): + image = Image.new('RGBA', size=(50, 50), color=(73, 109, 137)) + byte_arr = BytesIO() + image.save(byte_arr, format='PNG') + encoded_image = base64.b64encode(byte_arr.getvalue()) + return encoded_image.decode('utf-8') + + +@pytest.fixture +def stable_diffusion_tool(): + with patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config') as get_config_mock, \ + patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \ + patch('superagi.tools.image_generation.stable_diffusion_image_gen.ResourceManager') as resource_manager_mock: + get_config_mock.return_value = 'fake_api_key' + + # Create a mock response object + response_mock = Mock() + response_mock.status_code = 200 + response_mock.json.return_value = { + 'artifacts': [{'base64': create_sample_image_base64()} for _ in range(2)] + } + post_mock.return_value = response_mock + + resource_manager_mock.write_binary_file.return_value = None + + yield + +def test_execute(stable_diffusion_tool): + tool = StableDiffusionImageGenTool() + tool.resource_manager = Mock() + result = tool._execute('prompt', ['img1.png', 'img2.png']) + + assert result == 'Images downloaded and saved successfully' + tool.resource_manager.write_binary_file.assert_called() + +def test_call_stable_diffusion(stable_diffusion_tool): + tool = StableDiffusionImageGenTool() + response = tool.call_stable_diffusion('fake_api_key', 512, 512, 2, 'prompt', 50) + + assert response.status_code == 200 + assert 'artifacts' in response.json() \ No newline at end of file diff --git a/tests/unit_tests/tools/test_write_code.py b/tests/unit_tests/tools/test_write_code.py new file mode 100644 index 000000000..76f3a5877 --- /dev/null +++ b/tests/unit_tests/tools/test_write_code.py @@ -0,0 +1,37 @@ +from unittest.mock import Mock, patch +import pytest + +from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager +from superagi.tools.code.write_code import CodingTool +from superagi.tools.tool_response_query_manager import ToolResponseQueryManager + + +class MockBaseLlm: + def chat_completion(self, messages, max_tokens): + return {"content": "File1.py\n```python\nprint('Hello World')\n```\n\nFile2.py\n```python\nprint('Hello again')\n```"} + + def get_model(self): + return "gpt-3.5-turbo" + +class TestCodingTool: + + @pytest.fixture + def tool(self): + tool = CodingTool() + tool.llm = MockBaseLlm() + tool.resource_manager = Mock(spec=ResourceManager) + tool.tool_response_manager = Mock(spec=ToolResponseQueryManager) + return tool + + def test_execute(self, tool): + tool.resource_manager.write_file.return_value = "File write successful" + tool.tool_response_manager.get_last_response.return_value = "Mocked Spec" + + response = tool._execute("Test spec description") + assert response == "File1.py\n```python\nprint('Hello World')\n```\n\nFile2.py\n```python\nprint('Hello again')\n```\n Codes generated and saved successfully in File1.py, File2.py" + + tool.resource_manager.write_file.assert_any_call("README.md", 'File1.py\n') + tool.resource_manager.write_file.assert_any_call("File1.py", "print('Hello World')\n") + 
tool.resource_manager.write_file.assert_any_call("File2.py", "print('Hello again')\n") + tool.tool_response_manager.get_last_response.assert_called_once_with("WriteSpecTool") \ No newline at end of file diff --git a/tests/unit_tests/tools/test_write_spec.py b/tests/unit_tests/tools/test_write_spec.py new file mode 100644 index 000000000..05f85ed5c --- /dev/null +++ b/tests/unit_tests/tools/test_write_spec.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock + +import pytest + +from superagi.tools.code.write_spec import WriteSpecTool + + +class MockBaseLlm: + def chat_completion(self, messages, max_tokens): + return {"content": "Generated specification"} + + def get_model(self): + return "gpt-3.5-turbo" + +class TestWriteSpecTool: + + @pytest.fixture + def tool(self): + tool = WriteSpecTool() + tool.llm = MockBaseLlm() + tool.resource_manager = Mock() + return tool + + def test_execute(self, tool): + tool.resource_manager.write_file = Mock() + tool.resource_manager.write_file.return_value = "File write successful" + response = tool._execute("Test task description", "test_spec_file.txt") + assert response == "Generated specificationSpecification generated and saved successfully" + tool.resource_manager.write_file.assert_called_once_with("test_spec_file.txt", "Generated specification") diff --git a/tests/unit_tests/tools/test_write_test.py b/tests/unit_tests/tools/test_write_test.py new file mode 100644 index 000000000..3aba9c788 --- /dev/null +++ b/tests/unit_tests/tools/test_write_test.py @@ -0,0 +1,47 @@ +import pytest +from unittest.mock import Mock, patch +from superagi.llms.base_llm import BaseLlm +from superagi.resource_manager.manager import ResourceManager +from superagi.lib.logger import logger +from superagi.tools.code.write_test import WriteTestTool +from superagi.tools.tool_response_query_manager import ToolResponseQueryManager + + +def test_write_test_tool_init(): + tool = WriteTestTool() + assert tool.llm is None + assert tool.agent_id is None + assert tool.name == "WriteTestTool" + assert tool.description is not None + assert tool.goals == [] + assert tool.resource_manager is None + + +@patch('superagi.tools.code.write_test.logger') +@patch('superagi.tools.code.write_test.TokenCounter') +def test_write_test_tool_execute(mock_token_counter, mock_logger): + # Given + mock_llm = Mock(spec=BaseLlm) + mock_llm.get_model.return_value = None + mock_llm.chat_completion.return_value = {"content": "```python\nsample_code\n```"} + mock_token_counter.count_message_tokens.return_value = 0 + mock_token_counter.token_limit.return_value = 100 + + mock_resource_manager = Mock(spec=ResourceManager) + mock_resource_manager.write_file.return_value = "No error" + + mock_tool_response_manager = Mock(spec=ToolResponseQueryManager) + mock_tool_response_manager.get_last_response.return_value = "" + + tool = WriteTestTool() + tool.llm = mock_llm + tool.resource_manager = mock_resource_manager + tool.tool_response_manager = mock_tool_response_manager + + # When + result = tool._execute("spec", "test_file") + + # Then + mock_llm.chat_completion.assert_called_once() + mock_resource_manager.write_file.assert_called_once_with("test_file", "python\nsample_code") + assert "Tests generated and saved successfully in test_file" in result