diff --git a/gui/utils/utils.js b/gui/utils/utils.js
index 9cabc3936..86dccb2d5 100644
--- a/gui/utils/utils.js
+++ b/gui/utils/utils.js
@@ -79,4 +79,16 @@ export const refreshUrl = () => {
const urlWithoutToken = window.location.origin + window.location.pathname;
window.history.replaceState({}, document.title, urlWithoutToken);
-};
\ No newline at end of file
+};
+
+export const loadingTextEffect = (loadingText, setLoadingText, timer) => {
+ const text = loadingText;
+ let dots = '';
+
+ const interval = setInterval(() => {
+ dots = dots.length < 3 ? dots + '.' : '';
+ setLoadingText(`${text}${dots}`);
+ }, timer);
+
+ return () => clearInterval(interval)
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 71d7a6dd8..345d4f791 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -132,3 +132,7 @@ tiktoken==0.4.0
psycopg2==2.9.6
slack-sdk==3.21.3
pytest==7.3.2
+pylint==2.17.4
+pre-commit==3.3.3
+pytest-cov==4.1.0
+pytest-mock==3.11.1
diff --git a/superagi/config/config.py b/superagi/config/config.py
index 54067960f..421592845 100644
--- a/superagi/config/config.py
+++ b/superagi/config/config.py
@@ -25,17 +25,7 @@ def load_config(cls, config_file: str) -> dict:
logger.info("\033[91m\033[1m"
+ "\nConfig file not found. Enter required keys and values."
+ "\033[0m\033[0m")
- config_data = {
- "PINECONE_API_KEY": input("Pinecone API Key: "),
- "PINECONE_ENVIRONMENT": input("Pinecone Environment: "),
- # "OPENAI_API_KEY": input("OpenAI API Key: "),
- "GOOGLE_API_KEY": input("Google API Key: "),
- "SEARCH_ENGINE_ID": input("Search Engine ID: "),
- "RESOURCES_ROOT_DIR": input(
- "Resources Root Directory (default: /tmp/): "
- )
- or "/tmp/",
- }
+ config_data = {}
with open(config_file, "w") as file:
yaml.dump(config_data, file, default_flow_style=False)
diff --git a/superagi/helper/resource_helper.py b/superagi/helper/resource_helper.py
index 025e2e3b2..0449827f0 100644
--- a/superagi/helper/resource_helper.py
+++ b/superagi/helper/resource_helper.py
@@ -8,20 +8,19 @@
class ResourceHelper:
@staticmethod
- def make_written_file_resource(file_name: str, agent_id: int, file, channel):
+ def make_written_file_resource(file_name: str, agent_id: int, channel: str):
"""
Function to create a Resource object for a written file.
Args:
file_name (str): The name of the file.
agent_id (int): The ID of the agent.
- file (FileStorage): The file.
channel (str): The channel of the file.
Returns:
Resource: The Resource object.
"""
- path = get_config("RESOURCES_OUTPUT_ROOT_DIR")
+ path = ResourceHelper.get_root_dir()
storage_type = get_config("STORAGE_TYPE")
file_extension = os.path.splitext(file_name)[1][1:]
@@ -32,29 +31,59 @@ def make_written_file_resource(file_name: str, agent_id: int, file, channel):
else:
file_type = "application/misc"
- root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
-
- if root_dir is not None:
- root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
- root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
- final_path = root_dir + file_name
+ if agent_id is not None:
+ final_path = ResourceHelper.get_agent_resource_path(file_name, agent_id)
+ path = path + str(agent_id) + "/"
else:
- final_path = os.getcwd() + "/" + file_name
-
+ final_path = ResourceHelper.get_resource_path(file_name)
file_size = os.path.getsize(final_path)
if storage_type == "S3":
file_name_parts = file_name.split('.')
- file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '').replace('.', '').replace(
- ':', '') + '.' + file_name_parts[1]
- if channel == "INPUT":
- path = 'input'
- else:
- path = 'output'
-
- logger.info(path + "/" + file_name)
- resource = Resource(name=file_name, path=path + "/" + file_name, storage_type=storage_type, size=file_size,
+ file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '') \
+ .replace('.', '').replace(':', '') + '.' + file_name_parts[1]
+ path = 'input/' if (channel == "INPUT") else 'output/'
+
+ logger.info(final_path)
+ resource = Resource(name=file_name, path=path + file_name, storage_type=storage_type, size=file_size,
type=file_type,
channel="OUTPUT",
agent_id=agent_id)
return resource
+
+ @staticmethod
+ def get_resource_path(file_name: str):
+ """Get final path of the resource.
+
+ Args:
+ file_name (str): The name of the file.
+ """
+ return ResourceHelper.get_root_dir() + file_name
+
+ @staticmethod
+ def get_root_dir():
+ """Get root dir of the resource.
+ """
+ root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
+
+ if root_dir is not None:
+ root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
+ root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
+ else:
+ root_dir = os.getcwd() + "/"
+ return root_dir
+
+ @staticmethod
+ def get_agent_resource_path(file_name: str, agent_id: int):
+ """Get agent resource path
+
+ Args:
+ file_name (str): The name of the file.
+ """
+ root_dir = ResourceHelper.get_root_dir()
+ if agent_id is not None:
+ directory = os.path.dirname(root_dir + str(agent_id) + "/")
+ os.makedirs(directory, exist_ok=True)
+ root_dir = root_dir + str(agent_id) + "/"
+ final_path = root_dir + file_name
+ return final_path
diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py
index 636be0aea..8253b89dc 100644
--- a/superagi/jobs/agent_executor.py
+++ b/superagi/jobs/agent_executor.py
@@ -17,7 +17,9 @@
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.tool import Tool
+from superagi.resource_manager.manager import ResourceManager
from superagi.tools.thinking.tools import ThinkingTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.vector_factory import VectorFactory
from superagi.helper.encyption_helper import decrypt_data
@@ -164,7 +166,7 @@ def execute_next_action(self, agent_execution_id):
print(user_tools)
tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id,
- model_api_key=model_api_key)
+ model_api_key=model_api_key, session=session)
@@ -205,7 +207,7 @@ def execute_next_action(self, agent_execution_id):
# finally:
engine.dispose()
- def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key):
+ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key, session):
"""
Set the default parameters for the tools.
@@ -232,6 +234,12 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key
tool.image_llm = OpenAi(model=parsed_config["model"], api_key=model_api_key)
if hasattr(tool, 'agent_id'):
tool.agent_id = agent_id
+ if hasattr(tool, 'resource_manager'):
+ tool.resource_manager = ResourceManager(session=session, agent_id=agent_id)
+ if hasattr(tool, 'tool_response_manager'):
+ tool.tool_response_manager = ToolResponseQueryManager(session=session, agent_execution_id=parsed_config[
+ "agent_execution_id"])
+
new_tools.append(tool)
return tools
diff --git a/superagi/models/agent_execution_feed.py b/superagi/models/agent_execution_feed.py
index 9a9cb94f1..98acf5d26 100644
--- a/superagi/models/agent_execution_feed.py
+++ b/superagi/models/agent_execution_feed.py
@@ -1,4 +1,5 @@
from sqlalchemy import Column, Integer, Text, String
+from sqlalchemy.orm import Session
from superagi.models.base_model import DBBaseModel
@@ -36,3 +37,16 @@ def __repr__(self):
return f"AgentExecutionFeed(id={self.id}, " \
f"agent_execution_id={self.agent_execution_id}, " \
f"feed='{self.feed}', role='{self.role}', extra_info={self.extra_info})"
+
+ @classmethod
+ def get_last_tool_response(cls, session: Session, agent_execution_id: int, tool_name: str = None):
+ agent_execution_feeds = session.query(AgentExecutionFeed).filter(
+ AgentExecutionFeed.agent_execution_id == agent_execution_id,
+ AgentExecutionFeed.role == "system").order_by(AgentExecutionFeed.created_at.desc()).all()
+
+ for agent_execution_feed in agent_execution_feeds:
+ if tool_name and not agent_execution_feed.feed.startswith("Tool " + tool_name):
+ continue
+ if agent_execution_feed.feed.startswith("Tool"):
+ return agent_execution_feed.feed
+ return ""
diff --git a/tests/agent/__init__.py b/superagi/resource_manager/__init__.py
similarity index 100%
rename from tests/agent/__init__.py
rename to superagi/resource_manager/__init__.py
diff --git a/superagi/resource_manager/manager.py b/superagi/resource_manager/manager.py
new file mode 100644
index 000000000..565aa6276
--- /dev/null
+++ b/superagi/resource_manager/manager.py
@@ -0,0 +1,59 @@
+from sqlalchemy.orm import Session
+
+from superagi.helper.resource_helper import ResourceHelper
+from superagi.helper.s3_helper import S3Helper
+from superagi.lib.logger import logger
+import os
+
+
+class ResourceManager:
+ def __init__(self, session: Session, agent_id: int = None):
+ self.session = session
+ self.agent_id = agent_id
+
+ def write_binary_file(self, file_name: str, data):
+ if self.agent_id is not None:
+ final_path = ResourceHelper.get_agent_resource_path(file_name, self.agent_id)
+ else:
+ final_path = ResourceHelper.get_resource_path(file_name)
+
+ # if self.agent_id is not None:
+ # directory = os.path.dirname(final_path + "/" + str(self.agent_id) + "/")
+ # os.makedirs(directory, exist_ok=True)
+ try:
+ with open(final_path, mode="wb") as img:
+ img.write(data)
+ img.close()
+ self.write_to_s3(file_name, final_path)
+ logger.info(f"Binary {file_name} saved successfully")
+ return f"Binary {file_name} saved successfully"
+ except Exception as err:
+ return f"Error: {err}"
+
+ def write_to_s3(self, file_name, final_path):
+ with open(final_path, 'rb') as img:
+ resource = ResourceHelper.make_written_file_resource(file_name=file_name,
+ agent_id=self.agent_id, channel="OUTPUT")
+ if resource is not None:
+ self.session.add(resource)
+ self.session.commit()
+ self.session.flush()
+ if resource.storage_type == "S3":
+ s3_helper = S3Helper()
+ s3_helper.upload_file(img, path=resource.path)
+
+ def write_file(self, file_name: str, content):
+ if self.agent_id is not None:
+ final_path = ResourceHelper.get_agent_resource_path(file_name, self.agent_id)
+ else:
+ final_path = ResourceHelper.get_resource_path(file_name)
+
+ try:
+ with open(final_path, mode="w") as file:
+ file.write(content)
+ file.close()
+ self.write_to_s3(file_name, final_path)
+ logger.info(f"{file_name} saved successfully")
+ return f"{file_name} saved successfully"
+ except Exception as err:
+ return f"Error: {err}"
diff --git a/superagi/tools/code/tools.py b/superagi/tools/code/tools.py
deleted file mode 100644
index ca9b8ee29..000000000
--- a/superagi/tools/code/tools.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from typing import Type, Optional, List
-
-from pydantic import BaseModel, Field
-
-from superagi.agent.agent_prompt_builder import AgentPromptBuilder
-from superagi.llms.base_llm import BaseLlm
-from superagi.tools.base_tool import BaseTool
-from superagi.lib.logger import logger
-
-
-class CodingSchema(BaseModel):
- task_description: str = Field(
- ...,
- description="Coding task description.",
- )
-
-class CodingTool(BaseTool):
- """
- Used to generate code.
-
- Attributes:
- llm: LLM used for code generation.
- name : The name of tool.
- description : The description of tool.
- args_schema : The args schema.
- goals : The goals.
- """
- llm: Optional[BaseLlm] = None
- name = "CodingTool"
- description = (
- "Useful for writing, reviewing, and refactoring code. Can also fix bugs and explain programming concepts."
- )
- args_schema: Type[CodingSchema] = CodingSchema
- goals: List[str] = []
-
- class Config:
- arbitrary_types_allowed = True
-
-
- def _execute(self, task_description: str):
- """
- Execute the code tool.
-
- Args:
- task_description : The task description.
-
- Returns:
- Generated code or error message.
- """
- try:
- prompt = """You're a top-notch coder, knowing all programming languages, software systems, and architecture.
-
- Your high level goal is:
- {goals}
-
- Provide no information about who you are and focus on writing code.
- Ensure code is bug and error free and explain complex concepts through comments
- Respond in well-formatted markdown. Ensure code blocks are used for code sections.
-
- Write code to accomplish the following:
- {task}
- """
- prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
- prompt = prompt.replace("{task}", task_description)
- messages = [{"role": "system", "content": prompt}]
- result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
- return result["content"]
- except Exception as e:
- logger.error(e)
- return f"Error generating text: {e}"
\ No newline at end of file
diff --git a/superagi/tools/code/write_code.py b/superagi/tools/code/write_code.py
new file mode 100644
index 000000000..d322fc2d1
--- /dev/null
+++ b/superagi/tools/code/write_code.py
@@ -0,0 +1,143 @@
+import re
+from typing import Type, Optional, List
+
+from pydantic import BaseModel, Field
+
+from superagi.agent.agent_prompt_builder import AgentPromptBuilder
+from superagi.helper.token_counter import TokenCounter
+from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.manager import ResourceManager
+from superagi.tools.base_tool import BaseTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
+
+
+class CodingSchema(BaseModel):
+ code_description: str = Field(
+ ...,
+ description="Description of the coding task",
+ )
+
+class CodingTool(BaseTool):
+ """
+ Used to generate code.
+
+ Attributes:
+ llm: LLM used for code generation.
+ name : The name of tool.
+ description : The description of tool.
+ args_schema : The args schema.
+ goals : The goals.
+ resource_manager: Manages the file resources
+ """
+ llm: Optional[BaseLlm] = None
+ agent_id: int = None
+ name = "CodingTool"
+ description = (
+ "You will get instructions for code to write. You will write a very long answer. "
+ "Make sure that every detail of the architecture is, in the end, implemented as code. "
+ "Think step by step and reason yourself to the right decisions to make sure we get it right. "
+ "You will first lay out the names of the core classes, functions, methods that will be necessary, "
+ "as well as a quick comment on their purpose. Then you will output the content of each file including ALL code."
+ )
+ args_schema: Type[CodingSchema] = CodingSchema
+ goals: List[str] = []
+ resource_manager: Optional[ResourceManager] = None
+ tool_response_manager: Optional[ToolResponseQueryManager] = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+
+ def _execute(self, code_description: str) -> str:
+ """
+ Execute the write_code tool.
+
+ Args:
+ code_description : The coding task description.
+ code_file_name: The name of the file where the generated codes will be saved.
+
+ Returns:
+ Generated codes files or error message.
+ """
+ try:
+ prompt = """You are a super smart developer who practices good Development for writing code according to a specification.
+
+ Your high-level goal is:
+ {goals}
+
+ Coding task description:
+ {code_description}
+
+ {spec}
+
+ You will get instructions for code to write.
+ You need to write a detailed answer. Make sure all parts of the architecture are turned into code.
+ Think carefully about each step and make good choices to get it right. First, list the main classes,
+ functions, methods you'll use and a quick comment on their purpose.
+
+ Then you will output the content of each file including ALL code.
+ Each file must strictly follow a markdown code block format, where the following tokens must be replaced such that
+ [FILENAME] is the lowercase file name including the file extension,
+ [LANG] is the markup code block language for the code's language, and [CODE] is the code:
+ [FILENAME]
+ ```[LANG]
+ [CODE]
+ ```
+
+ You will start with the "entrypoint" file, then go to the ones that are imported by that file, and so on.
+ Please note that the code should be fully functional. No placeholders.
+
+ Follow a language and framework appropriate best practice file naming convention.
+ Make sure that files contain all imports, types etc. Make sure that code in different files are compatible with each other.
+ Ensure to implement all code, if you are unsure, write a plausible implementation.
+ Include module dependency or package manager dependency definition file.
+ Before you finish, double check that all parts of the architecture is present in the files.
+ """
+ prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
+ prompt = prompt.replace("{code_description}", code_description)
+ spec_response = self.tool_response_manager.get_last_response("WriteSpecTool")
+ if spec_response != "":
+ prompt = prompt.replace("{spec}", "Use this specs for generating the code:\n" + spec_response)
+ logger.info(prompt)
+ messages = [{"role": "system", "content": prompt}]
+
+ total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
+ token_limit = TokenCounter.token_limit(self.llm.get_model())
+ result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
+
+ # Get all filenames and corresponding code blocks
+ regex = r"(\S+?)\n```\S*\n(.+?)```"
+ matches = re.finditer(regex, result["content"], re.DOTALL)
+
+ file_names = []
+ # Save each file
+
+ for match in matches:
+ # Get the filename
+ file_name = re.sub(r'[<>"|?*]', "", match.group(1))
+
+ # Get the code
+ code = match.group(2)
+
+ # Ensure file_name is not empty
+ if not file_name.strip():
+ continue
+
+ file_names.append(file_name)
+ save_result = self.resource_manager.write_file(file_name, code)
+ if save_result.startswith("Error"):
+ return save_result
+
+ # Get README contents and save
+ split_result = result["content"].split("```")
+ if len(split_result) > 0:
+ readme = split_result[0]
+ save_readme_result = self.resource_manager.write_file("README.md", readme)
+ if save_readme_result.startswith("Error"):
+ return save_readme_result
+
+ return result["content"] + "\n Codes generated and saved successfully in " + ", ".join(file_names)
+ except Exception as e:
+ logger.error(e)
+ return f"Error generating codes: {e}"
diff --git a/superagi/tools/code/write_spec.py b/superagi/tools/code/write_spec.py
new file mode 100644
index 000000000..f15c89b15
--- /dev/null
+++ b/superagi/tools/code/write_spec.py
@@ -0,0 +1,97 @@
+from typing import Type, Optional, List
+
+from pydantic import BaseModel, Field
+from superagi.config.config import get_config
+from superagi.agent.agent_prompt_builder import AgentPromptBuilder
+import os
+
+from superagi.helper.token_counter import TokenCounter
+from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.manager import ResourceManager
+from superagi.tools.base_tool import BaseTool
+from superagi.lib.logger import logger
+from superagi.models.db import connect_db
+from superagi.helper.resource_helper import ResourceHelper
+from superagi.helper.s3_helper import S3Helper
+from sqlalchemy.orm import sessionmaker
+
+
+class WriteSpecSchema(BaseModel):
+ task_description: str = Field(
+ ...,
+ description="Specification task description.",
+ )
+
+ spec_file_name: str = Field(
+ ...,
+ description="Name of the file to write. Only include the file name. Don't include path."
+ )
+
+class WriteSpecTool(BaseTool):
+ """
+ Used to generate program specification.
+
+ Attributes:
+ llm: LLM used for specification generation.
+ name : The name of tool.
+ description : The description of tool.
+ args_schema : The args schema.
+ goals : The goals.
+ resource_manager: Manages the file resources
+ """
+ llm: Optional[BaseLlm] = None
+ agent_id: int = None
+ name = "WriteSpecTool"
+ description = (
+ "A tool to write the spec of a program."
+ )
+ args_schema: Type[WriteSpecSchema] = WriteSpecSchema
+ goals: List[str] = []
+ resource_manager: Optional[ResourceManager] = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def _execute(self, task_description: str, spec_file_name: str) -> str:
+ """
+ Execute the write_spec tool.
+
+ Args:
+ task_description : The task description.
+ spec_file_name: The name of the file where the generated specification will be saved.
+
+ Returns:
+ Generated specification or error message.
+ """
+ try:
+ prompt = """You are a super smart developer who has been asked to make a specification for a program.
+
+ Your high-level goal is:
+ {goals}
+
+ Please keep in mind the following when creating the specification:
+ 1. Be super explicit about what the program should do, which features it should have, and give details about anything that might be unclear.
+ 2. Lay out the names of the core classes, functions, methods that will be necessary, as well as a quick comment on their purpose.
+ 3. List all non-standard dependencies that will have to be used.
+
+ Write a specification for the following task:
+ {task}
+ """
+ prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
+ prompt = prompt.replace("{task}", task_description)
+ messages = [{"role": "system", "content": prompt}]
+
+ total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
+ token_limit = TokenCounter.token_limit(self.llm.get_model())
+ result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
+
+ # Save the specification to a file
+ write_result = self.resource_manager.write_file(spec_file_name, result["content"])
+ if not write_result.startswith("Error"):
+ return result["content"] + "Specification generated and saved successfully"
+ else:
+ return write_result
+
+ except Exception as e:
+ logger.error(e)
+ return f"Error generating specification: {e}"
\ No newline at end of file
diff --git a/superagi/tools/code/write_test.py b/superagi/tools/code/write_test.py
new file mode 100644
index 000000000..0060b2cb3
--- /dev/null
+++ b/superagi/tools/code/write_test.py
@@ -0,0 +1,110 @@
+import re
+from typing import Type, Optional, List
+
+from pydantic import BaseModel, Field
+
+from superagi.agent.agent_prompt_builder import AgentPromptBuilder
+from superagi.helper.token_counter import TokenCounter
+from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.manager import ResourceManager
+from superagi.tools.base_tool import BaseTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
+
+
+class WriteTestSchema(BaseModel):
+ test_description: str = Field(
+ ...,
+ description="Description of the testing task",
+ )
+ test_file_name: str = Field(
+ ...,
+ description="Name of the file to write. Only include the file name. Don't include path."
+ )
+
+
+class WriteTestTool(BaseTool):
+ """
+ Used to generate pytest unit tests based on the specification.
+
+ Attributes:
+ llm: LLM used for test generation.
+ name : The name of tool.
+ description : The description of tool.
+ args_schema : The args schema.
+ goals : The goals.
+ resource_manager: Manages the file resources
+ """
+ llm: Optional[BaseLlm] = None
+ agent_id: int = None
+ name = "WriteTestTool"
+ description = (
+ "You are a super smart developer using Test Driven Development to write tests according to a specification.\n"
+ "Please generate tests based on the above specification. The tests should be as simple as possible, "
+ "but still cover all the functionality.\n"
+ "Write it in the file"
+ )
+ args_schema: Type[WriteTestSchema] = WriteTestSchema
+ goals: List[str] = []
+ resource_manager: Optional[ResourceManager] = None
+ tool_response_manager: Optional[ToolResponseQueryManager] = None
+
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def _execute(self, test_description: str, test_file_name: str) -> str:
+ """
+ Execute the write_test tool.
+
+ Args:
+ test_description : The specification description.
+ test_file_name: The name of the file where the generated tests will be saved.
+
+ Returns:
+ Generated pytest unit tests or error message.
+ """
+ try:
+ prompt = """You are a super smart developer who practices Test Driven Development for writing tests according to a specification.
+
+ Your high-level goal is:
+ {goals}
+
+ Test Description:
+ {test_description}
+
+ {spec}
+
+ The tests should be as simple as possible, but still cover all the functionality described in the specification.
+ """
+ prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
+ prompt = prompt.replace("{test_description}", test_description)
+
+ spec_response = self.tool_response_manager.get_last_response("WriteSpecTool")
+ if spec_response != "":
+ prompt = prompt.replace("{spec}", "Please generate unit tests based on the following specification description:\n" + spec_response)
+
+ messages = [{"role": "system", "content": prompt}]
+ logger.info(prompt)
+
+ total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
+ token_limit = TokenCounter.token_limit(self.llm.get_model())
+ result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
+
+ # Extract the code part using regular expression
+ code = re.search(r'(?<=```).*?(?=```)', result["content"], re.DOTALL)
+ if code:
+ code_content = code.group(0).strip()
+ else:
+ return "Unable to extract code from the response"
+
+ # Save the tests to a file
+ save_result = self.resource_manager.write_file(test_file_name, code_content)
+ if not save_result.startswith("Error"):
+ return result["content"] + " \n Tests generated and saved successfully in " + test_file_name
+ else:
+ return save_result
+
+ except Exception as e:
+ logger.error(e)
+ return f"Error generating tests: {e}"
\ No newline at end of file
diff --git a/tests/agent_permissions/__init__.py b/superagi/tools/email/__init__.py
similarity index 100%
rename from tests/agent_permissions/__init__.py
rename to superagi/tools/email/__init__.py
diff --git a/superagi/tools/email/send_email.py b/superagi/tools/email/send_email.py
index 2d6641cf4..d475e2a71 100644
--- a/superagi/tools/email/send_email.py
+++ b/superagi/tools/email/send_email.py
@@ -57,7 +57,7 @@ def _execute(self, to: str, subject: str, body: str) -> str:
body += f"\n{signature}"
message.set_content(body)
draft_folder = get_config('EMAIL_DRAFT_MODE_WITH_FOLDER')
- send_to_draft = draft_folder is not None or draft_folder != "YOUR_DRAFTS_FOLDER"
+ send_to_draft = draft_folder is not None and draft_folder != "YOUR_DRAFTS_FOLDER"
if message["To"] == "example@example.com" or send_to_draft:
conn = ImapEmail().imap_open(draft_folder, email_sender, email_password)
conn.append(
diff --git a/superagi/tools/file/append_file.py b/superagi/tools/file/append_file.py
index b2f62d30a..72f3fae5f 100644
--- a/superagi/tools/file/append_file.py
+++ b/superagi/tools/file/append_file.py
@@ -3,6 +3,7 @@
from pydantic import BaseModel, Field
from superagi.config.config import get_config
+from superagi.helper.resource_helper import ResourceHelper
from superagi.tools.base_tool import BaseTool
@@ -38,13 +39,7 @@ def _execute(self, file_name: str, content: str):
Returns:
file written to successfully. or error message.
"""
- final_path = file_name
- root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
- if root_dir is not None:
- root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
- final_path = root_dir + file_name
- else:
- final_path = os.getcwd() + "/" + file_name
+ final_path = ResourceHelper.get_resource_path(file_name)
try:
directory = os.path.dirname(final_path)
os.makedirs(directory, exist_ok=True)
diff --git a/superagi/tools/file/delete_file.py b/superagi/tools/file/delete_file.py
index cba875cb5..3917f0a1a 100644
--- a/superagi/tools/file/delete_file.py
+++ b/superagi/tools/file/delete_file.py
@@ -3,6 +3,7 @@
from pydantic import BaseModel, Field
+from superagi.helper.resource_helper import ResourceHelper
from superagi.tools.base_tool import BaseTool
from superagi.config.config import get_config
@@ -36,13 +37,7 @@ def _execute(self, file_name: str, content: str):
Returns:
file deleted successfully. or error message.
"""
- final_path = file_name
- root_dir = get_config('RESOURCES_INPUT_ROOT_DIR')
- if root_dir is not None:
- root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
- final_path = root_dir + file_name
- else:
- final_path = os.getcwd() + "/" + file_name
+ final_path = ResourceHelper.get_resource_path(file_name)
try:
os.remove(final_path)
return "File deleted successfully."
diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py
index d11247e06..f425a3186 100644
--- a/superagi/tools/file/write_file.py
+++ b/superagi/tools/file/write_file.py
@@ -1,16 +1,12 @@
-import os
-from typing import Type
+from typing import Type, Optional
+
from pydantic import BaseModel, Field
+
+from superagi.resource_manager.manager import ResourceManager
from superagi.tools.base_tool import BaseTool
-from superagi.config.config import get_config
-from sqlalchemy.orm import sessionmaker
-from superagi.models.db import connect_db
-from superagi.helper.resource_helper import ResourceHelper
-# from superagi.helper.s3_helper import upload_to_s3
-from superagi.helper.s3_helper import S3Helper
-from superagi.lib.logger import logger
+# from superagi.helper.s3_helper import upload_to_s3
class WriteFileInput(BaseModel):
@@ -32,6 +28,10 @@ class WriteFileTool(BaseTool):
args_schema: Type[BaseModel] = WriteFileInput
description: str = "Writes text to a file"
agent_id: int = None
+ resource_manager: Optional[ResourceManager] = None
+
+ class Config:
+ arbitrary_types_allowed = True
def _execute(self, file_name: str, content: str):
"""
@@ -44,35 +44,5 @@ def _execute(self, file_name: str, content: str):
Returns:
file written to successfully. or error message.
"""
- engine = connect_db()
- Session = sessionmaker(bind=engine)
- session = Session()
-
- final_path = file_name
- root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
- if root_dir is not None:
- root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
- root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
- final_path = root_dir + file_name
- else:
- final_path = os.getcwd() + "/" + file_name
+ self.resource_manager.write_file(file_name, content)
- try:
- with open(final_path, 'w', encoding="utf-8") as file:
- file.write(content)
- file.close()
- with open(final_path, 'rb') as file:
- resource = ResourceHelper.make_written_file_resource(file_name=file_name,
- agent_id=self.agent_id,file=file,channel="OUTPUT")
- if resource is not None:
- session.add(resource)
- session.commit()
- session.flush()
- if resource.storage_type == "S3":
- s3_helper = S3Helper()
- s3_helper.upload_file(file, path=resource.path)
- logger.info("Resource Uploaded to S3!")
- session.close()
- return f"File written to successfully - {file_name}"
- except Exception as err:
- return f"Error: {err}"
diff --git a/superagi/tools/image_generation/README.STABLE_DIFFUSION.md b/superagi/tools/image_generation/README.STABLE_DIFFUSION.md
new file mode 100644
index 000000000..f4181544b
--- /dev/null
+++ b/superagi/tools/image_generation/README.STABLE_DIFFUSION.md
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+## SuperAGI Stable Diffusion Toolkit
+
+Introducing Stable Diffusion Integration with SuperAGI
+
+You can now use SuperAGI to summon Stable Diffusion to create true-to-life images and opens up a whole new range of possibilities.
+
+# ⚙️ Installation
+
+## 🛠️ Setting up SuperAGI
+
+Set-up SuperAGI by following the instruction given [here](https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
+
+## 🔧Configuring API from DreamStudio
+
+You can now get your API Key from Dream Studio to use Stable Diffusion by following the instructions below:
+
+1. Create an Account/Login with [DreamStudio.ai](http://DreamStudio.ai)
+
+![SD_1](README/SD_1.jpg)
+
+1. Click on the Profile Icon at the top right which will take you to the settings page. Once you have reached the settings page, you can now get your API keys
+
+![SD_2](README/SD_2.jpg)
+
+1. Copy the API Key and save it in a separate file.
+
+## 🛠️Configuring Stable Diffusion with SuperAGI
+
+You can configure SuperAGI with Stable Diffusion using the following steps:
+
+1. Navigate to the “****************Toolkit”**************** Page in SuperAGI’s Dashboard and select “****************Image Generation Toolkit”****************
+
+![SD_3](README/SD_3.jpg)
+
+1. Once you’ve clicked Image Generation Toolkit, it will open a page asking you for the API Key and the Model Engine. You can enter the generated API key from Dream Studio here.
+
+![SD_4](README/SD_4.jpg)
+3. If you would like to get more in-depth with the model of Stable Diffusion you’d like to use, you can choose between the following engine IDs:
+
+- 'stable-diffusion-v1'
+- 'stable-diffusion-v1-5'
+- 'stable-diffusion-512-v2-0'
+- 'stable-diffusion-768-v2-0'
+- 'stable-diffusion-512-v2-1'
+- ’stable-diffusion-768-v2-1'
+- 'stable-diffusion-xl-beta-v2-2-2’
+
+You have now successfully configured Stable Diffusion with SuperAGI!
\ No newline at end of file
diff --git a/superagi/tools/image_generation/README/SD_1.jpg b/superagi/tools/image_generation/README/SD_1.jpg
new file mode 100644
index 000000000..becead11b
Binary files /dev/null and b/superagi/tools/image_generation/README/SD_1.jpg differ
diff --git a/superagi/tools/image_generation/README/SD_2.jpg b/superagi/tools/image_generation/README/SD_2.jpg
new file mode 100644
index 000000000..01378871a
Binary files /dev/null and b/superagi/tools/image_generation/README/SD_2.jpg differ
diff --git a/superagi/tools/image_generation/README/SD_3.jpg b/superagi/tools/image_generation/README/SD_3.jpg
new file mode 100644
index 000000000..8e1a6450e
Binary files /dev/null and b/superagi/tools/image_generation/README/SD_3.jpg differ
diff --git a/superagi/tools/image_generation/README/SD_4.jpg b/superagi/tools/image_generation/README/SD_4.jpg
new file mode 100644
index 000000000..67a3e1c9b
Binary files /dev/null and b/superagi/tools/image_generation/README/SD_4.jpg differ
diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py
index cd80847f1..2b120efc2 100644
--- a/superagi/tools/image_generation/dalle_image_gen.py
+++ b/superagi/tools/image_generation/dalle_image_gen.py
@@ -1,44 +1,42 @@
from typing import Type, Optional
+
+import requests
from pydantic import BaseModel, Field
+
from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.manager import ResourceManager
from superagi.tools.base_tool import BaseTool
-from superagi.config.config import get_config
-import os
-import requests
-from superagi.models.db import connect_db
-from superagi.helper.resource_helper import ResourceHelper
-from superagi.helper.s3_helper import S3Helper
-from sqlalchemy.orm import sessionmaker
-from superagi.lib.logger import logger
-
-
-class ImageGenInput(BaseModel):
+class DalleImageGenInput(BaseModel):
prompt: str = Field(..., description="Prompt for Image Generation to be used by Dalle.")
size: int = Field(..., description="Size of the image to be Generated. default size is 512")
num: int = Field(..., description="Number of Images to be generated. default num is 2")
- image_name: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.")
+ image_names: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.")
-class ImageGenTool(BaseTool):
+class DalleImageGenTool(BaseTool):
"""
Dalle Image Generation tool
Attributes:
- name : The name.
- description : The description.
- args_schema : The args schema.
+ name : Name of the tool
+ description : The description
+ args_schema : The args schema
+ llm : The llm
+ agent_id : The agent id
+ resource_manager : Manages the file resources
"""
- name: str = "Dalle Image Generation"
- args_schema: Type[BaseModel] = ImageGenInput
+ name: str = "DalleImageGeneration"
+ args_schema: Type[BaseModel] = DalleImageGenInput
description: str = "Generate Images using Dalle"
llm: Optional[BaseLlm] = None
agent_id: int = None
+ resource_manager: Optional[ResourceManager] = None
class Config:
arbitrary_types_allowed = True
- def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2):
+ def _execute(self, prompt: str, image_names: list, size: int = 512, num: int = 2):
"""
Execute the Dalle Image Generation tool.
@@ -46,47 +44,17 @@ def _execute(self, prompt: str, image_name: list, size: int = 512, num: int = 2)
prompt : The prompt for image generation.
size : The size of the image to be generated.
num : The number of images to be generated.
- image_name (list): The name of the image to be generated.
+ image_names (list): The name of the image to be generated.
Returns:
Image generated successfully. or error message.
"""
- engine = connect_db()
- Session = sessionmaker(bind=engine)
- session = Session()
if size not in [256, 512, 1024]:
size = min([256, 512, 1024], key=lambda x: abs(x - size))
response = self.llm.generate_image(prompt, size, num)
response = response.__dict__
response = response['_previous']['data']
for i in range(num):
- image = image_name[i]
- final_path = image
- root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
- if root_dir is not None:
- root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
- root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
- final_path = root_dir + image
- else:
- final_path = os.getcwd() + "/" + image
- url = response[i]['url']
- data = requests.get(url).content
- try:
- with open(final_path, mode="wb") as img:
- img.write(data)
- with open(final_path, 'rb') as img:
- resource = ResourceHelper.make_written_file_resource(file_name=image_name[i],
- agent_id=self.agent_id, file=img,channel="OUTPUT")
- if resource is not None:
- session.add(resource)
- session.commit()
- session.flush()
- if resource.storage_type == "S3":
- s3_helper = S3Helper()
- s3_helper.upload_file(img, path=resource.path)
- logger.info(f"Image {image} saved successfully")
- except Exception as err:
- session.close()
- return f"Error: {err}"
- session.close()
+ data = requests.get(response[i]['url']).content
+ self.resource_manager.write_binary_file(image_names[i], data)
return "Images downloaded successfully"
diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py
index be16014ad..3f615d650 100644
--- a/superagi/tools/image_generation/stable_diffusion_image_gen.py
+++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py
@@ -1,17 +1,13 @@
+import base64
+from io import BytesIO
from typing import Type, Optional
+
+import requests
+from PIL import Image
from pydantic import BaseModel, Field
-from superagi.tools.base_tool import BaseTool
from superagi.config.config import get_config
-import os
-from PIL import Image
-from io import BytesIO
-import requests
-import base64
-from superagi.models.db import connect_db
-from superagi.helper.resource_helper import ResourceHelper
-from superagi.helper.s3_helper import S3Helper
-from sqlalchemy.orm import sessionmaker
-from superagi.lib.logger import logger
+from superagi.resource_manager.manager import ResourceManager
+from superagi.tools.base_tool import BaseTool
class StableDiffusionImageGenInput(BaseModel):
@@ -20,22 +16,32 @@ class StableDiffusionImageGenInput(BaseModel):
width: int = Field(..., description="Width of the image to be Generated. default width is 512")
num: int = Field(..., description="Number of Images to be generated. default num is 2")
steps: int = Field(..., description="Number of diffusion steps to run. default steps are 50")
- image_name: list = Field(...,
- description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.")
+ image_names: list = Field(...,
+ description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.")
class StableDiffusionImageGenTool(BaseTool):
+ """
+ Stable diffusion Image Generation tool
+
+ Attributes:
+ name : Name of the tool
+ description : The description
+ args_schema : The args schema
+ agent_id : The agent id
+ resource_manager : Manages the file resources
+ """
name: str = "Stable Diffusion Image Generation"
args_schema: Type[BaseModel] = StableDiffusionImageGenInput
description: str = "Generate Images using Stable Diffusion"
agent_id: int = None
+ resource_manager: Optional[ResourceManager] = None
- def _execute(self, prompt: str, image_name: list, width: int = 512, height: int = 512, num: int = 2,
- steps: int = 50):
- engine = connect_db()
- Session = sessionmaker(bind=engine)
- session = Session()
+ class Config:
+ arbitrary_types_allowed = True
+ def _execute(self, prompt: str, image_names: list, width: int = 512, height: int = 512, num: int = 2,
+ steps: int = 50):
api_key = get_config("STABILITY_API_KEY")
if api_key is None:
@@ -58,21 +64,11 @@ def _execute(self, prompt: str, image_name: list, width: int = 512, height: int
img_data = base64.b64decode(image_base64)
final_img = Image.open(BytesIO(img_data))
image_format = final_img.format
+ img_byte_arr = BytesIO()
+ final_img.save(img_byte_arr, format=image_format)
- image = image_name[i]
- root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
-
- final_path = self.build_file_path(image, root_dir)
+ self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue())
- try:
- self.upload_to_s3(final_img, final_path, image_format, image_name[i], session)
-
- logger.info(f"Image {image} saved successfully")
- except Exception as err:
- session.close()
- print(f"Error in _execute: {err}")
- return f"Error: {err}"
- session.close()
return "Images downloaded and saved successfully"
def call_stable_diffusion(self, api_key, width, height, num, prompt, steps):
@@ -90,11 +86,7 @@ def call_stable_diffusion(self, api_key, width, height, num, prompt, steps):
"Authorization": f"Bearer {api_key}"
},
json={
- "text_prompts": [
- {
- "text": prompt
- }
- ],
+ "text_prompts": [{"text": prompt}],
"height": height,
"width": width,
"samples": num,
@@ -102,27 +94,3 @@ def call_stable_diffusion(self, api_key, width, height, num, prompt, steps):
},
)
return response
-
- def upload_to_s3(self, final_img, final_path, image_format, file_name, session):
- with open(final_path, mode="wb") as img:
- final_img.save(img, format=image_format)
- with open(final_path, 'rb') as img:
- resource = ResourceHelper.make_written_file_resource(file_name=file_name,
- agent_id=self.agent_id, file=img, channel="OUTPUT")
- logger.info(resource)
- if resource is not None:
- session.add(resource)
- session.commit()
- session.flush()
- if resource.storage_type == "S3":
- s3_helper = S3Helper()
- s3_helper.upload_file(img, path=resource.path)
-
- def build_file_path(self, image, root_dir):
- if root_dir is not None:
- root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
- root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
- final_path = root_dir + image
- else:
- final_path = os.getcwd() + "/" + image
- return final_path
diff --git a/superagi/tools/slack/README.md b/superagi/tools/slack/README.md
new file mode 100644
index 000000000..bfb5bb4de
--- /dev/null
+++ b/superagi/tools/slack/README.md
@@ -0,0 +1,57 @@
+
+
+
+
+# SuperAGI Slack Toolkit
+
+This SuperAGI Tool lets users send messages to Slack Channels and provides a strong foundation for use cases to come.
+
+**Features:**
+
+1. Send Message - This tool gives SuperAGI the ability to send messages to Slack Channels that you have specified.
+
+## 🛠️ Installation
+
+Setting up of SuperAGI:
+
+Set up the SuperAGI by following the instructions given (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
+
+### 🔧 **Slack Configuration:**
+
+1. Create an Application on SlackAPI Portal
+
+ ![Slack_1](/README/Slack_1.jpg)
+
+2. Select "from scratch"
+
+ ![Slack_2](README/Slack_2.jpg)
+
+3. Add your application's name and the workspace for which you'd like to use your Slack Application on
+
+ ![Slack_3](README/Slack_3.jpg)
+
+4. Once the app creation process is done, head to the "OAuth and Permissions" tab
+
+ ![Slack_4](README/Slack_4.jpg)
+
+5. Find the “**bot token scopes”** and define the following scopes:
+
+ **"chat:write",** and save it
+
+ ![Slack_5](README/Slack_5.jpg)
+
+6. Once you've defined the scope, install the application to your workspace.
+
+
+ ![Slack_6](README/Slack_6.jpg)
+
+7. Post installation, you will get the bot token code
+
+
+ ![Slack_7](README/Slack_7.jpg)
+
+8. Once the installation is done, you'll get the Bot User OAuth Token, which needs to be added in the config.yaml beside the **"slack_bot_token"** variable.
+
+![Slack_8](README/Slack_8.jpg)
+
+Once the configuration is complete, you can install the app in the channel of your choice and create an agent on SuperAGI which can now send messages to the Slack Channel!
\ No newline at end of file
diff --git a/superagi/tools/slack/README/Slack_1.jpg b/superagi/tools/slack/README/Slack_1.jpg
new file mode 100644
index 000000000..8d37a189a
Binary files /dev/null and b/superagi/tools/slack/README/Slack_1.jpg differ
diff --git a/superagi/tools/slack/README/Slack_2.jpg b/superagi/tools/slack/README/Slack_2.jpg
new file mode 100644
index 000000000..e9f4d392f
Binary files /dev/null and b/superagi/tools/slack/README/Slack_2.jpg differ
diff --git a/superagi/tools/slack/README/Slack_3.jpg b/superagi/tools/slack/README/Slack_3.jpg
new file mode 100644
index 000000000..23b057a9e
Binary files /dev/null and b/superagi/tools/slack/README/Slack_3.jpg differ
diff --git a/superagi/tools/slack/README/Slack_4.jpg b/superagi/tools/slack/README/Slack_4.jpg
new file mode 100644
index 000000000..35694a79d
Binary files /dev/null and b/superagi/tools/slack/README/Slack_4.jpg differ
diff --git a/superagi/tools/slack/README/Slack_5.jpg b/superagi/tools/slack/README/Slack_5.jpg
new file mode 100644
index 000000000..d1b1acd96
Binary files /dev/null and b/superagi/tools/slack/README/Slack_5.jpg differ
diff --git a/superagi/tools/slack/README/Slack_6.jpg b/superagi/tools/slack/README/Slack_6.jpg
new file mode 100644
index 000000000..2b361b479
Binary files /dev/null and b/superagi/tools/slack/README/Slack_6.jpg differ
diff --git a/superagi/tools/slack/README/Slack_7.jpg b/superagi/tools/slack/README/Slack_7.jpg
new file mode 100644
index 000000000..c9305134c
Binary files /dev/null and b/superagi/tools/slack/README/Slack_7.jpg differ
diff --git a/superagi/tools/slack/README/Slack_8.jpg b/superagi/tools/slack/README/Slack_8.jpg
new file mode 100644
index 000000000..3016938df
Binary files /dev/null and b/superagi/tools/slack/README/Slack_8.jpg differ
diff --git a/superagi/tools/thinking/tools.py b/superagi/tools/thinking/tools.py
index 50c4c699a..302b3374d 100644
--- a/superagi/tools/thinking/tools.py
+++ b/superagi/tools/thinking/tools.py
@@ -1,15 +1,12 @@
-import os
-import openai
from typing import Type, Optional, List
from pydantic import BaseModel, Field
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
-from superagi.tools.base_tool import BaseTool
-from superagi.config.config import get_config
-from superagi.llms.base_llm import BaseLlm
-from pydantic import BaseModel, Field, PrivateAttr
from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.tools.base_tool import BaseTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
class ThinkingSchema(BaseModel):
@@ -36,6 +33,7 @@ class ThinkingTool(BaseTool):
args_schema: Type[ThinkingSchema] = ThinkingSchema
goals: List[str] = []
permission_required: bool = False
+ tool_response_manager: Optional[ToolResponseQueryManager] = None
class Config:
arbitrary_types_allowed = True
@@ -58,13 +56,17 @@ def _execute(self, task_description: str):
and the following task, `{task_description}`.
+ Below is last tool response:
+ `{last_tool_response}`
+
Perform the task by understanding the problem, extracting variables, and being smart
and efficient. Provide a descriptive response, make decisions yourself when
confronted with choices and provide reasoning for ideas / decisions.
"""
prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
prompt = prompt.replace("{task_description}", task_description)
-
+ last_tool_response = self.tool_response_manager.get_last_response()
+ prompt = prompt.replace("{last_tool_response}", last_tool_response)
messages = [{"role": "system", "content": prompt}]
result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
return result["content"]
diff --git a/superagi/tools/tool_response_query_manager.py b/superagi/tools/tool_response_query_manager.py
new file mode 100644
index 000000000..5a2387648
--- /dev/null
+++ b/superagi/tools/tool_response_query_manager.py
@@ -0,0 +1,12 @@
+from sqlalchemy.orm import Session
+
+from superagi.models.agent_execution_feed import AgentExecutionFeed
+
+
+class ToolResponseQueryManager:
+ def __init__(self, session: Session, agent_execution_id: int):
+ self.session = session
+ self.agent_execution_id = agent_execution_id
+
+ def get_last_response(self, tool_name: str = None):
+ return AgentExecutionFeed.get_last_tool_response(self.session, self.agent_execution_id, tool_name)
diff --git a/superagi/worker.py b/superagi/worker.py
index c107da67e..243eda64d 100644
--- a/superagi/worker.py
+++ b/superagi/worker.py
@@ -4,7 +4,7 @@
from celery import Celery
from superagi.config.config import get_config
-redis_url = get_config('REDIS_URL')
+redis_url = get_config('REDIS_URL') or 'localhost:6379'
app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"])
app.conf.broker_url = "redis://" + redis_url + "/0"
diff --git a/test.py b/test.py
index f8591dc46..eb466739b 100644
--- a/test.py
+++ b/test.py
@@ -38,23 +38,23 @@ def ask_user_for_goals():
return goals
-
-def run_superagi_cli(agent_name=None,agent_description=None,agent_goals=None):
+def run_superagi_cli(agent_name=None, agent_description=None, agent_goals=None):
# Create default organization
organization = Organisation(name='Default Organization', description='Default organization description')
session.add(organization)
session.flush() # Flush pending changes to generate the agent's ID
session.commit()
logger.info(organization)
-
+
# Create default project associated with the organization
- project = Project(name='Default Project', description='Default project description', organisation_id=organization.id)
+ project = Project(name='Default Project', description='Default project description',
+ organisation_id=organization.id)
session.add(project)
session.flush() # Flush pending changes to generate the agent's ID
session.commit()
logger.info(project)
- #Agent
+ # Agent
if agent_name is None:
agent_name = input("Enter agent name: ")
if agent_description is None:
@@ -65,24 +65,24 @@ def run_superagi_cli(agent_name=None,agent_description=None,agent_goals=None):
session.commit()
logger.info(agent)
- #Agent Config
+ # Agent Config
# Create Agent Configuration
agent_config_values = {
"goal": ask_user_for_goals() if agent_goals is None else agent_goals,
"agent_type": "Type Non-Queue",
- "constraints": [ "~4000 word limit for short term memory. ",
- "Your short term memory is short, so immediately save important information to files.",
- "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
- "No user assistance",
- "Exclusively use the commands listed in double quotes e.g. \"command name\""
- ],
+ "constraints": ["~4000 word limit for short term memory. ",
+ "Your short term memory is short, so immediately save important information to files.",
+ "If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
+ "No user assistance",
+ "Exclusively use the commands listed in double quotes e.g. \"command name\""
+ ],
"tools": [],
"exit": "Default",
"iteration_interval": 0,
"model": "gpt-4",
"permission_type": "Default",
"LTM_DB": "Pinecone",
- "memory_window":10
+ "memory_window": 10
}
# print("Id is ")
@@ -106,5 +106,6 @@ def run_superagi_cli(agent_name=None,agent_description=None,agent_goals=None):
logger.info(execution)
execute_agent.delay(execution.id, datetime.now())
-
-run_superagi_cli(agent_name=agent_name,agent_description=agent_description,agent_goals=agent_goals)
\ No newline at end of file
+
+
+run_superagi_cli(agent_name=agent_name, agent_description=agent_description, agent_goals=agent_goals)
diff --git a/tests/helper/__init__.py b/tests/integration_tests/__init__.py
similarity index 100%
rename from tests/helper/__init__.py
rename to tests/integration_tests/__init__.py
diff --git a/tests/integration_tests/vector_store/__init__.py b/tests/integration_tests/vector_store/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/vector_store/test_weaviate.py b/tests/integration_tests/vector_store/test_weaviate.py
similarity index 100%
rename from tests/vector_store/test_weaviate.py
rename to tests/integration_tests/vector_store/test_weaviate.py
diff --git a/tests/tools/email/__init__.py b/tests/tools/email/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/tools/email/test_send_email.py b/tests/tools/email/test_send_email.py
new file mode 100644
index 000000000..16d802477
--- /dev/null
+++ b/tests/tools/email/test_send_email.py
@@ -0,0 +1,70 @@
+from unittest.mock import MagicMock
+
+import pytest
+import imaplib
+import time
+from email.message import EmailMessage
+
+from superagi.config.config import get_config
+from superagi.helper.imap_email import ImapEmail
+from superagi.tools.email import send_email
+from superagi.tools.email.send_email import SendEmailTool
+
+def test_send_to_draft(mocker):
+
+ mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config', autospec=True)
+ mock_get_config.side_effect = [
+ 'test_sender@test.com', # EMAIL_ADDRESS
+ 'password', # EMAIL_PASSWORD
+ 'Test Signature', # EMAIL_SIGNATURE
+ "Draft", # EMAIL_DRAFT_MODE_WITH_FOLDER
+ 'smtp_host', # EMAIL_SMTP_HOST
+ 'smtp_port' # EMAIL_SMTP_PORT
+ ]
+
+
+ # Mocking the ImapEmail call
+ mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail')
+ mock_imap_instance = mock_imap_email.return_value.imap_open.return_value
+
+ # Mocking the SMTP call
+ mock_smtp = mocker.patch('smtplib.SMTP')
+ smtp_instance = mock_smtp.return_value
+
+ # Test the SendEmailTool's execute method
+ send_email_tool = SendEmailTool()
+ result = send_email_tool._execute('mukunda@contlo.com', 'Test Subject', 'Test Body')
+
+ # Assert the return value
+ assert result == 'Email went to Draft'
+
+def test_send_to_mailbox(mocker):
+ # Mocking the get_config calls
+ mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config')
+ mock_get_config.side_effect = [
+ 'test_sender@test.com', # EMAIL_ADDRESS
+ 'password', # EMAIL_PASSWORD
+ 'Test Signature', # EMAIL_SIGNATURE
+ "YOUR_DRAFTS_FOLDER", # EMAIL_DRAFT_MODE_WITH_FOLDER
+ 'smtp_host', # EMAIL_SMTP_HOST
+ 'smtp_port' # EMAIL_SMTP_PORT
+ ]
+
+ # mock_get_config.return_value = 'True'
+ # Mocking the ImapEmail call
+ mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail')
+ mock_imap_instance = mock_imap_email.return_value.imap_open.return_value
+
+ # Mocking the SMTP call
+ mock_smtp = mocker.patch('smtplib.SMTP')
+ smtp_instance = mock_smtp.return_value
+
+ # Test the SendEmailTool's execute method
+ send_email_tool = SendEmailTool()
+ result = send_email_tool._execute('test_receiver@test.com', 'Test Subject', 'Test Body')
+
+ # Assert that the ImapEmail was not called (no draft mode)
+ mock_imap_email.assert_not_called()
+
+ # Assert the return value
+ assert result == 'Email was sent to test_receiver@test.com'
\ No newline at end of file
diff --git a/tests/tools/image_gen_test.py b/tests/tools/image_gen_test.py
deleted file mode 100644
index f96454545..000000000
--- a/tests/tools/image_gen_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-import unittest
-from unittest.mock import patch, MagicMock
-
-from superagi.tools.image_generation.dalle_image_gen import ImageGenTool
-
-
-class TestImageGenTool(unittest.TestCase):
-
- @patch('openai.Image.create')
- @patch('requests.get')
- @patch('superagi.tools.image_generation.dalle_image_gen.get_config')
- def test_image_gen_tool_execute(self, mock_get_config, mock_requests_get, mock_openai_create):
- # Setup
- tool = ImageGenTool()
- prompt = 'Artificial Intelligence'
- image_names = ['image1.png', 'image2.png']
- size = 512
- num = 2
-
- # Mock responses
- mock_get_config.return_value = "/tmp"
- mock_openai_create.return_value = MagicMock(_previous=MagicMock(data=[
- {"url": "https://example.com/image1.png"},
- {"url": "https://example.com/image2.png"}
- ]))
- mock_requests_get.return_value.content = b"image_data"
-
- # Run the method under test
- response = tool._execute(prompt, image_names, size, num)
-
- # Assert the method ran correctly
- self.assertEqual(response, "Images downloaded successfully")
- for image_name in image_names:
- path = "/tmp/" + image_name
- self.assertTrue(os.path.exists(path))
- with open(path, "rb") as file:
- self.assertEqual(file.read(), b"image_data")
-
- # Clean up
- for image_name in image_names:
- os.remove("/tmp/" + image_name)
-
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
diff --git a/tests/tools/stable_diffusion_image_gen_test.py b/tests/tools/stable_diffusion_image_gen_test.py
deleted file mode 100644
index ae2ebc673..000000000
--- a/tests/tools/stable_diffusion_image_gen_test.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-import unittest
-from unittest.mock import patch, MagicMock
-from PIL import Image
-from io import BytesIO
-import base64
-from superagi.config.config import get_config
-
-from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool
-
-
-class TestStableDiffusionImageGenTool(unittest.TestCase):
-
- @patch('requests.post')
- @patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config')
- def test_stable_diffusion_image_gen_tool_execute(self, mock_get_config, mock_requests_post):
- # Setup
- tool = StableDiffusionImageGenTool()
- prompt = 'Artificial Intelligence'
- image_names = ['image1.png', 'image2.png']
- height = 512
- width = 512
- num = 2
- steps = 50
-
- # Create a temporary directory for image storage
- temp_dir = get_config("RESOURCES_OUTPUT_ROOT_DIR")
-
- # Mock responses
- mock_configs = {"STABILITY_API_KEY": "api_key", "ENGINE_ID": "engine_id", "RESOURCES_OUTPUT_ROOT_DIR": temp_dir}
- mock_get_config.side_effect = lambda k: mock_configs[k]
-
- # Prepare sample image bytes
- img = Image.new("RGB", (width, height), "white")
- buffer = BytesIO()
- img.save(buffer, "PNG")
- buffer.seek(0)
- img_data = buffer.getvalue()
- encoded_image_data = base64.b64encode(img_data).decode()
-
- # Use the proper base64-encoded string
- mock_requests_post.return_value = MagicMock(status_code=200, json=lambda: {
- "artifacts": [
- {"base64": encoded_image_data},
- {"base64": encoded_image_data}
- ]
- })
-
- # Run the method under test
- response = tool._execute(prompt, image_names, width, height, num, steps)
- self.assertEqual(response, f"Images downloaded successfully")
-
- for image_name in image_names:
- path = os.path.join(temp_dir, image_name)
- self.assertTrue(os.path.exists(path))
- with open(path, "rb") as file:
- self.assertEqual(file.read(), img_data)
-
- # Clean up
- for image_name in image_names:
- os.remove(os.path.join(temp_dir, image_name))
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/agent/__init__.py b/tests/unit_tests/agent/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/agent/test_task_queue.py b/tests/unit_tests/agent/test_task_queue.py
similarity index 98%
rename from tests/agent/test_task_queue.py
rename to tests/unit_tests/agent/test_task_queue.py
index 9627231be..85dd20f31 100644
--- a/tests/agent/test_task_queue.py
+++ b/tests/unit_tests/agent/test_task_queue.py
@@ -47,5 +47,6 @@ def test_get_last_task_details(self, mock_get_last_task_details):
self.queue.get_last_task_details()
mock_get_last_task_details.assert_called()
+
if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/tests/unit_tests/agent_permissions/__init__.py b/tests/unit_tests/agent_permissions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/agent_permissions/test_check_permission_in_restricted_mode.py b/tests/unit_tests/agent_permissions/test_check_permission_in_restricted_mode.py
similarity index 100%
rename from tests/agent_permissions/test_check_permission_in_restricted_mode.py
rename to tests/unit_tests/agent_permissions/test_check_permission_in_restricted_mode.py
diff --git a/tests/agent_permissions/test_handle_wait_for_permission.py b/tests/unit_tests/agent_permissions/test_handle_wait_for_permission.py
similarity index 100%
rename from tests/agent_permissions/test_handle_wait_for_permission.py
rename to tests/unit_tests/agent_permissions/test_handle_wait_for_permission.py
diff --git a/tests/unit_tests/helper/__init__.py b/tests/unit_tests/helper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/helper/test_github_helper.py b/tests/unit_tests/helper/test_github_helper.py
similarity index 100%
rename from tests/helper/test_github_helper.py
rename to tests/unit_tests/helper/test_github_helper.py
diff --git a/tests/helper/test_json_cleaner.py b/tests/unit_tests/helper/test_json_cleaner.py
similarity index 97%
rename from tests/helper/test_json_cleaner.py
rename to tests/unit_tests/helper/test_json_cleaner.py
index be64eaa6f..8579a9900 100644
--- a/tests/helper/test_json_cleaner.py
+++ b/tests/unit_tests/helper/test_json_cleaner.py
@@ -40,4 +40,4 @@ def test_clean_newline_spaces_json():
def test_has_newline_in_string():
test_str = r'{key: "value\n"\n \n}'
result = JsonCleaner.check_and_clean_json(test_str)
- assert result == '{key: "value\\n"}'
+ assert result == '{key: "value"}'
diff --git a/tests/unit_tests/helper/test_resource_helper.py b/tests/unit_tests/helper/test_resource_helper.py
new file mode 100644
index 000000000..6c24cecb2
--- /dev/null
+++ b/tests/unit_tests/helper/test_resource_helper.py
@@ -0,0 +1,39 @@
+import pytest
+from unittest.mock import patch
+from superagi.helper.resource_helper import ResourceHelper
+
+def test_make_written_file_resource(mocker):
+ mocker.patch('os.getcwd', return_value='/')
+ # mocker.patch('os.getcwd', return_value='/')
+ mocker.patch('os.makedirs', return_value=None)
+ mocker.patch('os.path.getsize', return_value=1000)
+ mocker.patch('os.path.splitext', return_value=("", ".txt"))
+ mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/', 'local', None])
+
+ with patch('superagi.helper.resource_helper.logger') as logger_mock:
+ result = ResourceHelper.make_written_file_resource('test.txt', 1, 'INPUT')
+
+ assert result.name == 'test.txt'
+ assert result.path == '/1/test.txt'
+ assert result.storage_type == 'local'
+ assert result.size == 1000
+ assert result.type == 'application/txt'
+ assert result.channel == 'OUTPUT'
+ assert result.agent_id == 1
+
+def test_get_resource_path(mocker):
+ mocker.patch('os.getcwd', return_value='/')
+ mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/'])
+
+ result = ResourceHelper.get_resource_path('test.txt')
+
+ assert result == '/test.txt'
+
+def test_get_agent_resource_path(mocker):
+ mocker.patch('os.getcwd', return_value='/')
+ mocker.patch('os.makedirs')
+ mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/'])
+
+ result = ResourceHelper.get_agent_resource_path('test.txt', 1)
+
+ assert result == '/1/test.txt'
diff --git a/tests/unit_tests/models/__init__.py b/tests/unit_tests/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/models/test_agent_execution_feed.py b/tests/unit_tests/models/test_agent_execution_feed.py
new file mode 100644
index 000000000..15924c551
--- /dev/null
+++ b/tests/unit_tests/models/test_agent_execution_feed.py
@@ -0,0 +1,27 @@
+import pytest
+from unittest.mock import Mock, create_autospec
+from sqlalchemy.orm import Session
+from superagi.models.agent_execution_feed import AgentExecutionFeed
+
+
+def test_get_last_tool_response():
+ mock_session = create_autospec(Session)
+ agent_execution_feed_1 = AgentExecutionFeed(id=1, agent_execution_id=2, feed="Tool test1", role='system')
+ agent_execution_feed_2 = AgentExecutionFeed(id=2, agent_execution_id=2, feed="Tool test2", role='system')
+
+ mock_session.query().filter().order_by().all.return_value = [agent_execution_feed_1, agent_execution_feed_2]
+
+ result = AgentExecutionFeed.get_last_tool_response(mock_session, 2)
+
+ assert result == agent_execution_feed_1.feed # as agent_execution_feed_1 should be the latest based on created_at
+
+
+def test_get_last_tool_response_with_tool_name():
+ mock_session = create_autospec(Session)
+ agent_execution_feed_1 = AgentExecutionFeed(id=1, agent_execution_id=2, feed="Tool test1", role='system')
+ agent_execution_feed_2 = AgentExecutionFeed(id=2, agent_execution_id=2, feed="Tool test2", role='system')
+
+ mock_session.query().filter().order_by().all.return_value = [agent_execution_feed_1, agent_execution_feed_2]
+
+ result = AgentExecutionFeed.get_last_tool_response(mock_session, 2, "test2")
+ assert result == agent_execution_feed_2.feed
diff --git a/tests/unit_tests/resource_manager/__init__.py b/tests/unit_tests/resource_manager/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/resource_manager/test_resource_manager.py b/tests/unit_tests/resource_manager/test_resource_manager.py
new file mode 100644
index 000000000..a4630f0d5
--- /dev/null
+++ b/tests/unit_tests/resource_manager/test_resource_manager.py
@@ -0,0 +1,37 @@
+import pytest
+from unittest.mock import Mock, patch
+from superagi.models.resource import Resource
+from superagi.helper.resource_helper import ResourceHelper
+from superagi.helper.s3_helper import S3Helper
+from superagi.lib.logger import logger
+
+from superagi.resource_manager.manager import ResourceManager
+
+@pytest.fixture
+def resource_manager():
+ session_mock = Mock()
+ resource_manager = ResourceManager(session_mock)
+ #resource_manager.agent_id = 1 # replace with actual value
+ return resource_manager
+
+
+def test_write_binary_file(resource_manager):
+ with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'), \
+ patch.object(ResourceHelper, 'make_written_file_resource',
+ return_value=Resource(name='test.png', storage_type='S3')), \
+ patch.object(S3Helper, 'upload_file'), \
+ patch.object(logger, 'info') as logger_mock:
+ result = resource_manager.write_binary_file('test.png', b'data')
+ assert result == "Binary test.png saved successfully"
+ logger_mock.assert_called_once_with("Binary test.png saved successfully")
+
+
+def test_write_file(resource_manager):
+ with patch.object(ResourceHelper, 'get_resource_path', return_value='test_path'), \
+ patch.object(ResourceHelper, 'make_written_file_resource',
+ return_value=Resource(name='test.txt', storage_type='S3')), \
+ patch.object(S3Helper, 'upload_file'), \
+ patch.object(logger, 'info') as logger_mock:
+ result = resource_manager.write_file('test.txt', 'content')
+ assert result == "test.txt saved successfully"
+ logger_mock.assert_called_once_with("test.txt saved successfully")
diff --git a/tests/unit_tests/tools/__init__.py b/tests/unit_tests/tools/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/tools/test_dalle_image_gen.py b/tests/unit_tests/tools/test_dalle_image_gen.py
new file mode 100644
index 000000000..82653d5a2
--- /dev/null
+++ b/tests/unit_tests/tools/test_dalle_image_gen.py
@@ -0,0 +1,27 @@
+from unittest.mock import Mock, patch
+import pytest
+from superagi.tools.image_generation.dalle_image_gen import DalleImageGenTool
+
+
+class MockBaseLlm:
+ def generate_image(self, prompt, size, num):
+ return Mock(_previous={"data": [{"url": f"https://example.com/image_{i}.png"} for i in range(num)]})
+
+
+class TestDalleImageGenTool:
+
+ @pytest.fixture
+ def tool(self):
+ tool = DalleImageGenTool()
+ tool.llm = MockBaseLlm()
+ response_mock = Mock()
+ tool.resource_manager = response_mock
+ return tool
+
+ @patch("requests.get")
+ def test_execute(self, mock_get, tool):
+ mock_get.return_value = Mock(content=b"fake image data")
+ response = tool._execute("test prompt", ["test1.png", "test2.png"], size=512, num=2)
+ assert response == "Images downloaded successfully"
+ mock_get.assert_called_with("https://example.com/image_1.png")
+ assert tool.resource_manager.write_binary_file.call_count == 2
diff --git a/tests/unit_tests/tools/test_send_email.py b/tests/unit_tests/tools/test_send_email.py
new file mode 100644
index 000000000..b0d80928e
--- /dev/null
+++ b/tests/unit_tests/tools/test_send_email.py
@@ -0,0 +1,61 @@
+from superagi.tools.email.send_email import SendEmailTool
+import pytest
+
+def test_send_to_draft(mocker):
+
+ mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config', autospec=True)
+ mock_get_config.side_effect = [
+ 'test_sender@test.com', # EMAIL_ADDRESS
+ 'password', # EMAIL_PASSWORD
+ 'Test Signature', # EMAIL_SIGNATURE
+ "Draft", # EMAIL_DRAFT_MODE_WITH_FOLDER
+ 'smtp_host', # EMAIL_SMTP_HOST
+ 'smtp_port' # EMAIL_SMTP_PORT
+ ]
+
+
+ # Mocking the ImapEmail call
+ mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail')
+ mock_imap_instance = mock_imap_email.return_value.imap_open.return_value
+
+ # Mocking the SMTP call
+ mock_smtp = mocker.patch('smtplib.SMTP')
+ smtp_instance = mock_smtp.return_value
+
+ # Test the SendEmailTool's execute method
+ send_email_tool = SendEmailTool()
+ result = send_email_tool._execute('mukunda@contlo.com', 'Test Subject', 'Test Body')
+
+ # Assert the return value
+ assert result == 'Email went to Draft'
+
+def test_send_to_mailbox(mocker):
+ # Mocking the get_config calls
+ mock_get_config = mocker.patch('superagi.tools.email.send_email.get_config')
+ mock_get_config.side_effect = [
+ 'test_sender@test.com', # EMAIL_ADDRESS
+ 'password', # EMAIL_PASSWORD
+ 'Test Signature', # EMAIL_SIGNATURE
+ "YOUR_DRAFTS_FOLDER", # EMAIL_DRAFT_MODE_WITH_FOLDER
+ 'smtp_host', # EMAIL_SMTP_HOST
+ 'smtp_port' # EMAIL_SMTP_PORT
+ ]
+
+ # mock_get_config.return_value = 'True'
+ # Mocking the ImapEmail call
+ mock_imap_email = mocker.patch('superagi.tools.email.send_email.ImapEmail')
+ mock_imap_instance = mock_imap_email.return_value.imap_open.return_value
+
+ # Mocking the SMTP call
+ mock_smtp = mocker.patch('smtplib.SMTP')
+ smtp_instance = mock_smtp.return_value
+
+ # Test the SendEmailTool's execute method
+ send_email_tool = SendEmailTool()
+ result = send_email_tool._execute('test_receiver@test.com', 'Test Subject', 'Test Body')
+
+ # Assert that the ImapEmail was not called (no draft mode)
+ mock_imap_email.assert_not_called()
+
+ # Assert the return value
+ assert result == 'Email was sent to test_receiver@test.com'
\ No newline at end of file
diff --git a/tests/unit_tests/tools/test_stable_diffusion_image_gen.py b/tests/unit_tests/tools/test_stable_diffusion_image_gen.py
new file mode 100644
index 000000000..6d2dc75c0
--- /dev/null
+++ b/tests/unit_tests/tools/test_stable_diffusion_image_gen.py
@@ -0,0 +1,51 @@
+import base64
+from io import BytesIO
+from unittest.mock import patch, Mock
+
+import pytest
+from PIL import Image
+
+from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool
+
+
+def create_sample_image_base64():
+ image = Image.new('RGBA', size=(50, 50), color=(73, 109, 137))
+ byte_arr = BytesIO()
+ image.save(byte_arr, format='PNG')
+ encoded_image = base64.b64encode(byte_arr.getvalue())
+ return encoded_image.decode('utf-8')
+
+
+@pytest.fixture
+def stable_diffusion_tool():
+ with patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config') as get_config_mock, \
+ patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \
+ patch('superagi.tools.image_generation.stable_diffusion_image_gen.ResourceManager') as resource_manager_mock:
+ get_config_mock.return_value = 'fake_api_key'
+
+ # Create a mock response object
+ response_mock = Mock()
+ response_mock.status_code = 200
+ response_mock.json.return_value = {
+ 'artifacts': [{'base64': create_sample_image_base64()} for _ in range(2)]
+ }
+ post_mock.return_value = response_mock
+
+ resource_manager_mock.write_binary_file.return_value = None
+
+ yield
+
+def test_execute(stable_diffusion_tool):
+ tool = StableDiffusionImageGenTool()
+ tool.resource_manager = Mock()
+ result = tool._execute('prompt', ['img1.png', 'img2.png'])
+
+ assert result == 'Images downloaded and saved successfully'
+ tool.resource_manager.write_binary_file.assert_called()
+
+def test_call_stable_diffusion(stable_diffusion_tool):
+ tool = StableDiffusionImageGenTool()
+ response = tool.call_stable_diffusion('fake_api_key', 512, 512, 2, 'prompt', 50)
+
+ assert response.status_code == 200
+ assert 'artifacts' in response.json()
\ No newline at end of file
diff --git a/tests/unit_tests/tools/test_write_code.py b/tests/unit_tests/tools/test_write_code.py
new file mode 100644
index 000000000..76f3a5877
--- /dev/null
+++ b/tests/unit_tests/tools/test_write_code.py
@@ -0,0 +1,37 @@
+from unittest.mock import Mock, patch
+import pytest
+
+from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.manager import ResourceManager
+from superagi.tools.code.write_code import CodingTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
+
+
+class MockBaseLlm:
+ def chat_completion(self, messages, max_tokens):
+ return {"content": "File1.py\n```python\nprint('Hello World')\n```\n\nFile2.py\n```python\nprint('Hello again')\n```"}
+
+ def get_model(self):
+ return "gpt-3.5-turbo"
+
+class TestCodingTool:
+
+ @pytest.fixture
+ def tool(self):
+ tool = CodingTool()
+ tool.llm = MockBaseLlm()
+ tool.resource_manager = Mock(spec=ResourceManager)
+ tool.tool_response_manager = Mock(spec=ToolResponseQueryManager)
+ return tool
+
+ def test_execute(self, tool):
+ tool.resource_manager.write_file.return_value = "File write successful"
+ tool.tool_response_manager.get_last_response.return_value = "Mocked Spec"
+
+ response = tool._execute("Test spec description")
+ assert response == "File1.py\n```python\nprint('Hello World')\n```\n\nFile2.py\n```python\nprint('Hello again')\n```\n Codes generated and saved successfully in File1.py, File2.py"
+
+ tool.resource_manager.write_file.assert_any_call("README.md", 'File1.py\n')
+ tool.resource_manager.write_file.assert_any_call("File1.py", "print('Hello World')\n")
+ tool.resource_manager.write_file.assert_any_call("File2.py", "print('Hello again')\n")
+ tool.tool_response_manager.get_last_response.assert_called_once_with("WriteSpecTool")
\ No newline at end of file
diff --git a/tests/unit_tests/tools/test_write_spec.py b/tests/unit_tests/tools/test_write_spec.py
new file mode 100644
index 000000000..05f85ed5c
--- /dev/null
+++ b/tests/unit_tests/tools/test_write_spec.py
@@ -0,0 +1,29 @@
+from unittest.mock import Mock
+
+import pytest
+
+from superagi.tools.code.write_spec import WriteSpecTool
+
+
+class MockBaseLlm:
+ def chat_completion(self, messages, max_tokens):
+ return {"content": "Generated specification"}
+
+ def get_model(self):
+ return "gpt-3.5-turbo"
+
+class TestWriteSpecTool:
+
+ @pytest.fixture
+ def tool(self):
+ tool = WriteSpecTool()
+ tool.llm = MockBaseLlm()
+ tool.resource_manager = Mock()
+ return tool
+
+ def test_execute(self, tool):
+ tool.resource_manager.write_file = Mock()
+ tool.resource_manager.write_file.return_value = "File write successful"
+ response = tool._execute("Test task description", "test_spec_file.txt")
+ assert response == "Generated specificationSpecification generated and saved successfully"
+ tool.resource_manager.write_file.assert_called_once_with("test_spec_file.txt", "Generated specification")
diff --git a/tests/unit_tests/tools/test_write_test.py b/tests/unit_tests/tools/test_write_test.py
new file mode 100644
index 000000000..3aba9c788
--- /dev/null
+++ b/tests/unit_tests/tools/test_write_test.py
@@ -0,0 +1,47 @@
+import pytest
+from unittest.mock import Mock, patch
+from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.manager import ResourceManager
+from superagi.lib.logger import logger
+from superagi.tools.code.write_test import WriteTestTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
+
+
+def test_write_test_tool_init():
+ tool = WriteTestTool()
+ assert tool.llm is None
+ assert tool.agent_id is None
+ assert tool.name == "WriteTestTool"
+ assert tool.description is not None
+ assert tool.goals == []
+ assert tool.resource_manager is None
+
+
+@patch('superagi.tools.code.write_test.logger')
+@patch('superagi.tools.code.write_test.TokenCounter')
+def test_write_test_tool_execute(mock_token_counter, mock_logger):
+ # Given
+ mock_llm = Mock(spec=BaseLlm)
+ mock_llm.get_model.return_value = None
+ mock_llm.chat_completion.return_value = {"content": "```python\nsample_code\n```"}
+ mock_token_counter.count_message_tokens.return_value = 0
+ mock_token_counter.token_limit.return_value = 100
+
+ mock_resource_manager = Mock(spec=ResourceManager)
+ mock_resource_manager.write_file.return_value = "No error"
+
+ mock_tool_response_manager = Mock(spec=ToolResponseQueryManager)
+ mock_tool_response_manager.get_last_response.return_value = ""
+
+ tool = WriteTestTool()
+ tool.llm = mock_llm
+ tool.resource_manager = mock_resource_manager
+ tool.tool_response_manager = mock_tool_response_manager
+
+ # When
+ result = tool._execute("spec", "test_file")
+
+ # Then
+ mock_llm.chat_completion.assert_called_once()
+ mock_resource_manager.write_file.assert_called_once_with("test_file", "python\nsample_code")
+ assert "Tests generated and saved successfully in test_file" in result