diff --git a/superagi/helper/s3_helper.py b/superagi/helper/s3_helper.py index 2b2d0ba1f..85e1462d0 100644 --- a/superagi/helper/s3_helper.py +++ b/superagi/helper/s3_helper.py @@ -8,9 +8,8 @@ import json - class S3Helper: - def __init__(self, bucket_name = get_config("BUCKET_NAME")): + def __init__(self, bucket_name=get_config("BUCKET_NAME")): """ Initialize the S3Helper class. Using the AWS credentials from the configuration file, create a boto3 client. @@ -84,7 +83,7 @@ def get_json_file(self, path): """ try: obj = self.s3.get_object(Bucket=self.bucket_name, Key=path) - s3_response = obj['Body'].read().decode('utf-8') + s3_response = obj['Body'].read().decode('utf-8') return json.loads(s3_response) except: raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") @@ -108,9 +107,43 @@ def delete_file(self, path): logger.info("File deleted from S3 successfully!") except: raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") - + def upload_file_content(self, content, file_path): try: self.s3.put_object(Bucket=self.bucket_name, Key=file_path, Body=content) except: - raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") \ No newline at end of file + raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.") + + def get_download_url_of_resources(self, db_resources_arr): + s3 = boto3.client( + 's3', + aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"), + ) + response_obj = {} + for db_resource in db_resources_arr: + response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=db_resource.path) + content = response["Body"].read() + bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME") + file_name = db_resource.path.split('/')[-1] + file_name = ''.join(char for char in file_name if char != "`") + object_key = f"public_resources/run_id{db_resource.agent_execution_id}/{file_name}" + s3.put_object(Bucket=bucket_name, Key=object_key, Body=content) + file_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}" + resource_execution_id = db_resource.agent_execution_id + if resource_execution_id in response_obj: + response_obj[resource_execution_id].append(file_url) + else: + response_obj[resource_execution_id] = [file_url] + return response_obj + + def list_files_from_s3(self, file_path): + file_path = "resources" + file_path + logger.info(f"Listing files from s3 with prefix: {file_path}") + response = self.s3.list_objects_v2(Bucket=get_config("BUCKET_NAME"), Prefix=file_path) + + if 'Contents' in response: + file_list = [obj['Key'] for obj in response['Contents']] + return file_list + + raise Exception(f"Error listing files from s3") diff --git a/superagi/tools/file/list_files.py b/superagi/tools/file/list_files.py index 5f2414329..7290e53db 100644 --- a/superagi/tools/file/list_files.py +++ b/superagi/tools/file/list_files.py @@ -4,8 +4,11 @@ from pydantic import BaseModel, Field from superagi.helper.resource_helper import ResourceHelper +from superagi.helper.s3_helper import S3Helper from superagi.tools.base_tool import BaseTool from superagi.models.agent import Agent +from superagi.types.storage_types import StorageType +from superagi.config.config import get_config class ListFileInput(BaseModel): @@ -52,6 +55,8 @@ def _execute(self): return input_files #+ output_files def list_files(self, directory): + if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3: + return S3Helper().list_files_from_s3(directory) found_files = [] for root, dirs, files in os.walk(directory): for file in files: diff --git a/tests/unit_tests/helper/test_s3_helper.py b/tests/unit_tests/helper/test_s3_helper.py index bceb5d67f..d476c7a7c 100644 --- a/tests/unit_tests/helper/test_s3_helper.py +++ b/tests/unit_tests/helper/test_s3_helper.py @@ -102,3 +102,29 @@ def test_delete_file_fail(s3helper_object): s3helper_object.s3.delete_object = MagicMock(side_effect=Exception()) with pytest.raises(HTTPException): s3helper_object.delete_file('path') + + +def test_list_files_from_s3(s3helper_object): + s3helper_object.s3.list_objects_v2 = MagicMock(return_value={ + 'Contents': [{'Key': 'path/to/file1.txt'}, {'Key': 'path/to/file2.jpg'}] + }) + + file_list = s3helper_object.list_files_from_s3('path/to/') + + assert len(file_list) == 2 + assert 'path/to/file1.txt' in file_list + assert 'path/to/file2.jpg' in file_list + + +def test_list_files_from_s3_no_contents(s3helper_object): + s3helper_object.s3.list_objects_v2 = MagicMock(return_value={}) + + with pytest.raises(Exception): + s3helper_object.list_files_from_s3('path/to/') + + +def test_list_files_from_s3_raises_exception(s3helper_object): + s3helper_object.s3.list_objects_v2 = MagicMock(side_effect=Exception("An error occurred")) + + with pytest.raises(Exception): + s3helper_object.list_files_from_s3('path/to/')