diff --git a/.github/workflows/increment_version_dev.yaml b/.github/workflows/increment_version_dev.yaml index b5303b077..e1734c333 100644 --- a/.github/workflows/increment_version_dev.yaml +++ b/.github/workflows/increment_version_dev.yaml @@ -24,11 +24,11 @@ jobs: - name: Increment versions in pyproject.toml run: | set -x - + echo "Incrementing versions..." find . -name "pyproject.toml" | while read -r pyproject; do echo "Processing $pyproject" - + # Extract current version CURRENT_VERSION=$(python -c " import tomlkit @@ -43,31 +43,30 @@ jobs: print(f'Error reading version from {pyproject}: {e}', end='') exit(1) ") - + echo "Extracted CURRENT_VERSION: $CURRENT_VERSION" - + if [ -z "$CURRENT_VERSION" ]; then echo "Error: Could not extract the current version from $pyproject" cat "$pyproject" continue fi - + # Increment version BASE_VERSION=$(echo "$CURRENT_VERSION" | sed -E 's/(.*)-dev.*/\1/') DEV_PART=$(echo "$CURRENT_VERSION" | grep -oE 'dev[0-9]+$' | grep -oE '[0-9]+') - + + # Fallback if no DEV_PART is found if [ -z "$DEV_PART" ]; then DEV_PART=0 fi - + NEW_DEV_PART=$((DEV_PART + 1)) NEW_VERSION="${BASE_VERSION}-dev${NEW_DEV_PART}" - + echo "Updating version from $CURRENT_VERSION to $NEW_VERSION" done - - - name: Commit changes run: | git config user.name "github-actions[bot]" diff --git a/.github/workflows/sequence_publish.yaml b/.github/workflows/sequence_publish.yaml index 23a88d990..f219023f4 100644 --- a/.github/workflows/sequence_publish.yaml +++ b/.github/workflows/sequence_publish.yaml @@ -93,7 +93,7 @@ jobs: - uses: actions/checkout@v4 - name: Wait for swarmauri - run: sleep 60 + run: sleep 120 - name: Set up Python 3.12 uses: actions/setup-python@v5 diff --git a/.github/workflows/test_changed_files.yaml b/.github/workflows/test_changed_files.yaml index 5846c5c09..19104ba24 100644 --- a/.github/workflows/test_changed_files.yaml +++ b/.github/workflows/test_changed_files.yaml @@ -99,6 +99,9 @@ jobs: run-tests: needs: detect-changed-files runs-on: ubuntu-latest + permissions: + issues: write + contents: read if: ${{ needs.detect-changed-files.outputs.matrix != '[]' }} strategy: fail-fast: false @@ -134,9 +137,26 @@ jobs: cd pkgs/${{ matrix.package_tests.package }} poetry install --no-cache --all-extras -vv - - name: Run all tests for the package + - name: Run tests and save results run: | echo "Running tests for package: ${{ matrix.package_tests.package }}" echo "Test files: ${{ matrix.package_tests.tests }}" cd pkgs/${{ matrix.package_tests.package }} - poetry run pytest ${{ matrix.package_tests.tests }} + poetry run pytest -v ${{ matrix.package_tests.tests }} -n 4 --dist=loadfile --tb=short --json-report --json-report-file=pytest_results.json || true + + - name: Classify test results + run: | + python scripts/classify_json_results.py pkgs/${{ matrix.package_tests.package }}/pytest_results.json --required-passed ge:50 --required-skipped lt:30 + continue-on-error: false + + - name: Process test results and manage issues + if: | + github.event.pull_request.head.repo.full_name == github.repository && always() + env: + DEEPINFRA_API_KEY: ${{ secrets.DEEPINFRA_API_KEY }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + REPO: ${{ github.repository }} + run: | + cd pkgs/swarmauri # Change to the directory containing pyproject.toml + poetry run python ../../scripts/rag_issue_manager.py --results-file=pytest_results.json --package=${{ matrix.package_tests.package }} + diff --git a/pkgs/community/pyproject.toml b/pkgs/community/pyproject.toml index a82494195..76584e7d0 100644 --- 
a/pkgs/community/pyproject.toml +++ b/pkgs/community/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri-community" -version = "0.5.2.dev20" +version = "0.5.3.dev5" description = "This repository includes Swarmauri community components." authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -15,48 +15,74 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.10,<3.13" -captcha = "*" -chromadb = "*" -duckdb = "*" -folium = "*" -gensim = "*" -#google-generativeai = "*" -gradio = "*" -leptonai = "0.22.0" -neo4j = "*" -nltk = "*" -#openai = "^1.52.0" -pandas = "*" -psutil = "*" -pygithub = "*" -python-dotenv = "*" -qrcode = "*" -redis = "^4.0" -scikit-learn="^1.4.2" -swarmauri = "==0.5.2" -textstat = "*" -transformers = ">=4.45.0" -typing_extensions = "*" -tiktoken = "*" -pymupdf = "*" -annoy = "*" -qdrant_client = "*" -weaviate = "*" -pinecone-client = { version = "*", extras = ["grpc"] } -PyPDF2 = "*" -pypdftk = "*" -weaviate-client = "*" -protobuf = "^3.20.0" -# Pacmap requires specific version of numba -#numba = ">=0.59.0" -#pacmap = "==0.7.3" +captcha = "^0.6.0" + +# We should remove and only rely on httpx +requests = "^2.32.3" + +chromadb = { version = "^0.5.17", optional = true } +duckdb = { version = "^1.1.1", optional = true } +folium = { version = "^0.18.0", optional = true } +gensim = { version = "^4.3.3", optional = true } +gradio = { version = "^5.4.0", optional = true } +leptonai = { version = "^0.22.0", optional = true } +neo4j = { version = "^5.25.0", optional = true } +nltk = { version = "^3.9.1", optional = true } +pandas = "^2.2.3" +psutil = { version = "^6.1.0", optional = true } +pygithub = { version = "^2.4.0", optional = true } +qrcode = { version = "^8.0", optional = true } +redis = { version = "^4.0", optional = true } +swarmauri = "==0.5.3.dev5" +textstat = { version = "^0.7.4", optional = true } +transformers = { version = ">=4.45.0", optional = true } +typing_extensions = "^4.12.2" +tiktoken = { version = "^0.8.0", optional = true } +PyMuPDF = { version = "^1.24.12", optional = true } +annoy = { version = "^1.17.3", optional = true } +qdrant-client = { version = "^1.12.0", optional = true } +pinecone-client = { version = "^5.0.1", optional = true, extras = ["grpc"] } +pypdf = { version = "^5.0.1", optional = true } +pypdftk = { version = "^0.5", optional = true } +weaviate-client = { version = "^4.9.2", optional = true } +textblob = { version = "^0.18.0", optional = true } +torch = { version = "^2.4.1", optional = true} +scikit-learn = { version = "^1.5.2", optional = true } +#protobuf = { version = "^3.20.0", optional = true } + +[tool.poetry.extras] +# Grouped optional dependencies +nlp = ["nltk", "gensim", "textstat", "textblob", "torch", "scikit-learn"] +ml_toolkits = ["transformers", "annoy"] +visualization = ["folium"] +storage = ["redis", "duckdb", "neo4j", "chromadb", "qdrant-client", "weaviate-client", "pinecone-client"] +document_processing = ["pypdf", "PyMuPDF", "pypdftk"] +cloud_integration = ["psutil", "qrcode", "pygithub"] +gradio = ["gradio"] +model_clients = ["leptonai"] +tiktoken = ["tiktoken"] +# Full installation +full = [ + "nltk", "gensim", "textstat", "textblob", "torch", "scikit-learn", + "transformers", "annoy", + "folium", + "redis", "duckdb", "neo4j", "chromadb", "qdrant-client", "weaviate-client", "pinecone-client", + "pypdf", "PyMuPDF", "pypdftk", + "psutil", "qrcode", "pygithub", + "gradio", + "leptonai", + "tiktoken" +] [tool.poetry.dev-dependencies] -flake8 = "^7.0" # Add flake8 as a development dependency -pytest = 
"^8.0" # Ensure pytest is also added if you run tests +flake8 = "^7.0" +pytest = "^8.0" pytest-asyncio = ">=0.24.0" pytest-xdist = "^3.6.1" +pytest-json-report = "^1.5.0" +python-dotenv = "*" +requests = "^2.32.3" [build-system] requires = ["poetry-core>=1.0.0"] @@ -70,12 +96,10 @@ markers = [ "unit: Unit tests", "integration: Integration tests", "acceptance: Acceptance tests", - "experimental: Experimental tests", + "experimental: Experimental tests" ] - log_cli = true log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)s] %(message)s" log_cli_date_format = "%Y-%m-%d %H:%M:%S" - asyncio_default_fixture_loop_scope = "function" diff --git a/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py b/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py index a44ceb5c2..941fc5618 100644 --- a/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/document_stores/concrete/__init__.py @@ -1,3 +1,10 @@ -from swarmauri_community.document_stores.concrete.RedisDocumentStore import ( - RedisDocumentStore, -) +from swarmauri.utils._lazy_import import _lazy_import + +documents_stores_files = [ + ("swarmauri_community.documents_stores.concrete.RedisDocumentStore", "RedisDocumentStore"), +] + +for module_name, class_name in documents_stores_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in documents_stores_files] diff --git a/pkgs/community/swarmauri_community/embeddings/__init__.py b/pkgs/community/swarmauri_community/embeddings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/community/swarmauri_community/embeddings/base/__init__.py b/pkgs/community/swarmauri_community/embeddings/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/embeddings/concrete/Doc2VecEmbedding.py b/pkgs/community/swarmauri_community/embeddings/concrete/Doc2VecEmbedding.py similarity index 100% rename from pkgs/swarmauri/swarmauri/embeddings/concrete/Doc2VecEmbedding.py rename to pkgs/community/swarmauri_community/embeddings/concrete/Doc2VecEmbedding.py diff --git a/pkgs/swarmauri/swarmauri/embeddings/concrete/MlmEmbedding.py b/pkgs/community/swarmauri_community/embeddings/concrete/MlmEmbedding.py similarity index 100% rename from pkgs/swarmauri/swarmauri/embeddings/concrete/MlmEmbedding.py rename to pkgs/community/swarmauri_community/embeddings/concrete/MlmEmbedding.py diff --git a/pkgs/community/swarmauri_community/embeddings/concrete/__init__.py b/pkgs/community/swarmauri_community/embeddings/concrete/__init__.py new file mode 100644 index 000000000..7bcc482d7 --- /dev/null +++ b/pkgs/community/swarmauri_community/embeddings/concrete/__init__.py @@ -0,0 +1,12 @@ +from swarmauri.utils._lazy_import import _lazy_import + + +embeddings_files = [ + ("swarmauri_community.embeddings.concrete.Doc2VecEmbedding", "Doc2VecEmbedding"), + ("swarmauri_community.embeddings.concrete.MlmEmbedding", "MlmEmbedding"), +] + +for module_name, class_name in embeddings_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in embeddings_files] diff --git a/pkgs/community/swarmauri_community/llms/concrete/__init__.py b/pkgs/community/swarmauri_community/llms/concrete/__init__.py index a8fa703c0..5c2266ce5 100644 --- a/pkgs/community/swarmauri_community/llms/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/llms/concrete/__init__.py 
@@ -1,4 +1,12 @@ -from swarmauri_community.llms.concrete.LeptonAIImgGenModel import LeptonAIImgGenModel -from swarmauri_community.llms.concrete.LeptonAIModel import LeptonAIModel +from swarmauri.utils._lazy_import import _lazy_import -__all__ = ["LeptonAIImgGenModel", "LeptonAIModel"] +llms_files = [ + ("swarmauri_community.llms.concrete.LeptonAIImgGenModel", "LeptonAIImgGenModel"), + ("swarmauri_community.llms.concrete.LeptonAIModel", "LeptonAIModel"), + ("swarmauri_community.llms.concrete.PytesseractImg2TextModel", "PytesseractImg2TextModel"), +] + +for module_name, class_name in llms_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in llms_files] diff --git a/pkgs/community/swarmauri_community/measurements/concrete/__init__.py b/pkgs/community/swarmauri_community/measurements/concrete/__init__.py index 276716315..a748c6a0e 100644 --- a/pkgs/community/swarmauri_community/measurements/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/measurements/concrete/__init__.py @@ -1,6 +1,11 @@ -from swarmauri_community.measurements.concrete.MutualInformationMeasurement import ( - MutualInformationMeasurement, -) -from swarmauri_community.measurements.concrete.TokenCountEstimatorMeasurement import ( - TokenCountEstimatorMeasurement, -) +from swarmauri.utils._lazy_import import _lazy_import + +measurement_files = [ + ("swarmauri_community.measurements.concrete.MutualInformationMeasurement", "MutualInformationMeasurement"), + ("swarmauri_community.measurements.concrete.TokenCountEstimatorMeasurement", "TokenCountEstimatorMeasurement"), +] + +for module_name, class_name in measurement_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in measurement_files] diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/BERTEmbeddingParser.py b/pkgs/community/swarmauri_community/parsers/concrete/BERTEmbeddingParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/BERTEmbeddingParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/BERTEmbeddingParser.py diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/EntityRecognitionParser.py b/pkgs/community/swarmauri_community/parsers/concrete/EntityRecognitionParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/EntityRecognitionParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/EntityRecognitionParser.py diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobNounParser.py b/pkgs/community/swarmauri_community/parsers/concrete/TextBlobNounParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobNounParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/TextBlobNounParser.py diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobSentenceParser.py b/pkgs/community/swarmauri_community/parsers/concrete/TextBlobSentenceParser.py similarity index 100% rename from pkgs/swarmauri/swarmauri/parsers/concrete/TextBlobSentenceParser.py rename to pkgs/community/swarmauri_community/parsers/concrete/TextBlobSentenceParser.py diff --git a/pkgs/community/swarmauri_community/parsers/concrete/__init__.py b/pkgs/community/swarmauri_community/parsers/concrete/__init__.py index b5d547c4e..3808b4733 100644 --- a/pkgs/community/swarmauri_community/parsers/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/parsers/concrete/__init__.py @@ -1,3 +1,16 @@ -from 
swarmauri_community.parsers.concrete.FitzPdfParser import PDFtoTextParser -from swarmauri_community.parsers.concrete.PyPDF2Parser import PyPDF2Parser -from swarmauri_community.parsers.concrete.PyPDFTKParser import PyPDFTKParser +from swarmauri.utils._lazy_import import _lazy_import + +parsers_files = [ + ("swarmauri_community.parsers.concrete.BERTEmbeddingParser", "BERTEmbeddingParser"), + ("swarmauri_community.parsers.concrete.EntityRecognitionParser", "EntityRecognitionParser"), + ("swarmauri_community.parsers.concrete.FitzPdfParser", "PDFtoTextParser"), + ("swarmauri_community.parsers.concrete.PyPDF2Parser", "PyPDF2Parser"), + ("swarmauri_community.parsers.concrete.PyPDFTKParser", "PyPDFTKParser"), + ("swarmauri_community.parsers.concrete.TextBlobNounParser", "TextBlobNounParser"), + ("swarmauri_community.parsers.concrete.TextBlobSentenceParser", "TextBlobSentenceParser"), +] + +for module_name, class_name in parsers_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in parsers_files] diff --git a/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py b/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py index 000e57ffe..ec089ab66 100644 --- a/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/retrievers/concrete/__init__.py @@ -1,5 +1,10 @@ -# -*- coding: utf-8 -*- +from swarmauri.utils._lazy_import import _lazy_import -from swarmauri_community.retrievers.concrete.RedisDocumentRetriever import ( - RedisDocumentRetriever, -) +retriever_files = [ + ("swarmauri_community.retrievers.concrete.RedisDocumentRetriever", "RedisDocumentRetriever"), +] + +for module_name, class_name in retriever_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in retriever_files] diff --git a/pkgs/community/swarmauri_community/state/__init__.py b/pkgs/community/swarmauri_community/state/__init__.py new file mode 100644 index 000000000..3828350c1 --- /dev/null +++ b/pkgs/community/swarmauri_community/state/__init__.py @@ -0,0 +1 @@ +from swarmauri_community.state.concrete import * diff --git a/pkgs/community/swarmauri_community/state/concrete/ClipboardState.py b/pkgs/community/swarmauri_community/state/concrete/ClipboardState.py new file mode 100644 index 000000000..53f73bcde --- /dev/null +++ b/pkgs/community/swarmauri_community/state/concrete/ClipboardState.py @@ -0,0 +1,69 @@ +import ast +import pyperclip +from typing import Dict, Any +from swarmauri.state.base.StateBase import StateBase + + +class ClipboardState(StateBase): + """ + A concrete implementation of StateBase that uses the system clipboard to store and retrieve state data. + """ + + def read(self) -> Dict[str, Any]: + """ + Reads the current state from the clipboard as a dictionary. + """ + try: + clipboard_content = pyperclip.paste() + # Only parse non-empty clipboard content + if clipboard_content: + # ast.literal_eval safely parses the dict repr() produced by write() + return ast.literal_eval(clipboard_content) + return {} + except Exception as e: + raise ValueError(f"Failed to read state from clipboard: {e}") + + def write(self, data: Dict[str, Any]) -> None: + """ + Replaces the current state with the given data by copying it to the clipboard.
+ """ + try: + pyperclip.copy( + str(data) + ) # Convert dictionary to string for clipboard storage + except Exception as e: + raise ValueError(f"Failed to write state to clipboard: {e}") + + def update(self, data: Dict[str, Any]) -> None: + """ + Updates the current state with the given data by merging with clipboard content. + """ + try: + current_state = self.read() + current_state.update(data) + self.write(current_state) + except Exception as e: + raise ValueError(f"Failed to update state on clipboard: {e}") + + def reset(self) -> None: + """ + Resets the clipboard state to an empty dictionary. + """ + try: + self.write({}) + except Exception as e: + raise ValueError(f"Failed to reset clipboard state: {e}") + + def deep_copy(self) -> "ClipboardState": + """ + Creates a deep copy of the current state. In this context, simply returns a new ClipboardState with the same clipboard data. + """ + try: + current_state = self.read() + new_instance = ClipboardState() + new_instance.write(current_state) + return new_instance + except Exception as e: + raise ValueError( + f"Failed to create a deep copy of the clipboard state: {e}" + ) diff --git a/pkgs/community/swarmauri_community/state/concrete/__init__.py b/pkgs/community/swarmauri_community/state/concrete/__init__.py new file mode 100644 index 000000000..92405b4a9 --- /dev/null +++ b/pkgs/community/swarmauri_community/state/concrete/__init__.py @@ -0,0 +1,10 @@ +from swarmauri.utils._lazy_import import _lazy_import + +state_files = [ + ("swarmauri_community.state.concrete.ClipboardState", "ClipboardState"), +] + +for module_name, class_name in state_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in state_files] diff --git a/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py b/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py index 129ad1dc9..6aca27ddc 100644 --- a/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/toolkits/concrete/__init__.py @@ -1 +1,10 @@ -from swarmauri_community.toolkits.concrete.GithubToolkit import * +from swarmauri.utils._lazy_import import _lazy_import + +toolkits_files = [ + ("swarmauri_community.toolkits.concrete.GithubToolkit", "GithubToolkit"), +] + +for module_name, class_name in toolkits_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in toolkits_files] diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/TextLengthTool.py b/pkgs/community/swarmauri_community/tools/concrete/TextLengthTool.py similarity index 100% rename from pkgs/swarmauri/swarmauri/tools/concrete/TextLengthTool.py rename to pkgs/community/swarmauri_community/tools/concrete/TextLengthTool.py diff --git a/pkgs/community/swarmauri_community/tools/concrete/__init__.py b/pkgs/community/swarmauri_community/tools/concrete/__init__.py index 9db51c17f..fe50a5544 100644 --- a/pkgs/community/swarmauri_community/tools/concrete/__init__.py +++ b/pkgs/community/swarmauri_community/tools/concrete/__init__.py @@ -1,31 +1,33 @@ -from swarmauri_community.tools.concrete.CaptchaGeneratorTool import CaptchaGeneratorTool -from swarmauri_community.tools.concrete.DaleChallReadabilityTool import ( - DaleChallReadabilityTool, -) -from swarmauri_community.tools.concrete.DownloadPdfTool import DownloadPDFTool -from swarmauri_community.tools.concrete.EntityRecognitionTool import ( - EntityRecognitionTool, -) -from swarmauri_community.tools.concrete.FoliumTool 
import FoliumTool -from swarmauri_community.tools.concrete.GithubBranchTool import GithubBranchTool -from swarmauri_community.tools.concrete.GithubCommitTool import GithubCommitTool -from swarmauri_community.tools.concrete.GithubIssueTool import GithubIssueTool -from swarmauri_community.tools.concrete.GithubPRTool import GithubPRTool -from swarmauri_community.tools.concrete.GithubRepoTool import GithubRepoTool -from swarmauri_community.tools.concrete.GithubTool import GithubTool -from swarmauri_community.tools.concrete.GmailReadTool import GmailReadTool -from swarmauri_community.tools.concrete.GmailSendTool import GmailSendTool -from swarmauri_community.tools.concrete.LexicalDensityTool import LexicalDensityTool -from swarmauri_community.tools.concrete.PsutilTool import PsutilTool -from swarmauri_community.tools.concrete.QrCodeGeneratorTool import QrCodeGeneratorTool -from swarmauri_community.tools.concrete.SentenceComplexityTool import ( - SentenceComplexityTool, -) -from swarmauri_community.tools.concrete.SentimentAnalysisTool import ( - SentimentAnalysisTool, -) -from swarmauri_community.tools.concrete.SMOGIndexTool import SMOGIndexTool -from swarmauri_community.tools.concrete.WebScrapingTool import WebScrapingTool -from swarmauri_community.tools.concrete.ZapierHookTool import ZapierHookTool +from swarmauri.utils._lazy_import import _lazy_import -# from swarmauri_community.tools.concrete.PaCMAPTool import PaCMAPTool +tool_files = [ + ("swarmauri_community.tools.concrete.CaptchaGeneratorTool", "CaptchaGeneratorTool"), + ("swarmauri_community.tools.concrete.DaleChallReadabilityTool", "DaleChallReadabilityTool"), + ("swarmauri_community.tools.concrete.DownloadPdfTool", "DownloadPDFTool"), + ("swarmauri_community.tools.concrete.EntityRecognitionTool", "EntityRecognitionTool"), + ("swarmauri_community.tools.concrete.FoliumTool", "FoliumTool"), + ("swarmauri_community.tools.concrete.GithubBranchTool", "GithubBranchTool"), + ("swarmauri_community.tools.concrete.GithubCommitTool", "GithubCommitTool"), + ("swarmauri_community.tools.concrete.GithubIssueTool", "GithubIssueTool"), + ("swarmauri_community.tools.concrete.GithubPRTool", "GithubPRTool"), + ("swarmauri_community.tools.concrete.GithubRepoTool", "GithubRepoTool"), + ("swarmauri_community.tools.concrete.GithubTool", "GithubTool"), + ("swarmauri_community.tools.concrete.GmailReadTool", "GmailReadTool"), + ("swarmauri_community.tools.concrete.GmailSendTool", "GmailSendTool"), + ("swarmauri_community.tools.concrete.LexicalDensityTool", "LexicalDensityTool"), + ("swarmauri_community.tools.concrete.PsutilTool", "PsutilTool"), + ("swarmauri_community.tools.concrete.QrCodeGeneratorTool", "QrCodeGeneratorTool"), + ("swarmauri_community.tools.concrete.SentenceComplexityTool", "SentenceComplexityTool"), + ("swarmauri_community.tools.concrete.SentimentAnalysisTool", "SentimentAnalysisTool"), + ("swarmauri_community.tools.concrete.SMOGIndexTool", "SMOGIndexTool"), + ("swarmauri_community.tools.concrete.WebScrapingTool", "WebScrapingTool"), + ("swarmauri_community.tools.concrete.ZapierHookTool", "ZapierHookTool"), + # ("swarmauri_community.tools.concrete.PaCMAPTool", "PaCMAPTool"), +] + +# Lazy loading of tools, storing them in variables +for module_name, class_name in tool_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded tools to __all__ +__all__ = [class_name for _, class_name in tool_files] diff --git a/pkgs/community/swarmauri_community/vector_stores/base/__init__.py 
b/pkgs/community/swarmauri_community/vector_stores/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/community/swarmauri_community/vector_stores/AnnoyVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/AnnoyVectorStore.py similarity index 99% rename from pkgs/community/swarmauri_community/vector_stores/AnnoyVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/AnnoyVectorStore.py index 354b2c2fe..bcc4fc214 100644 --- a/pkgs/community/swarmauri_community/vector_stores/AnnoyVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/AnnoyVectorStore.py @@ -4,7 +4,7 @@ import os from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/community/swarmauri_community/vector_stores/CloudQdrantVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudQdrantVectorStore.py similarity index 99% rename from pkgs/community/swarmauri_community/vector_stores/CloudQdrantVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/CloudQdrantVectorStore.py index e1af20b9a..28f82d250 100644 --- a/pkgs/community/swarmauri_community/vector_stores/CloudQdrantVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudQdrantVectorStore.py @@ -10,7 +10,7 @@ ) from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudWeaviateVectorStore.py similarity index 96% rename from pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/CloudWeaviateVectorStore.py index 9e55fc0ca..770528174 100644 --- a/pkgs/community/swarmauri_community/vector_stores/CloudWeaviateVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/CloudWeaviateVectorStore.py @@ -1,218 +1,218 @@ -from typing import List, Union, Literal, Optional -from pydantic import BaseModel, PrivateAttr -import uuid as ud -import weaviate -from weaviate.classes.init import Auth -from weaviate.util import generate_uuid5 -from weaviate.classes.query import MetadataQuery - -from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding -from swarmauri.vectors.concrete.Vector import Vector - -from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase -from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin -from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin -from swarmauri.vector_stores.base.VectorStoreCloudMixin import VectorStoreCloudMixin - - -class CloudWeaviateVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase, 
VectorStoreCloudMixin): - type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore" - - - # Private attributes - _client: Optional[weaviate.Client] = PrivateAttr(default=None) - _embedder: Doc2VecEmbedding = PrivateAttr(default=None) - _namespace_uuid: ud.UUID = PrivateAttr(default_factory=ud.uuid4) - - def __init__(self, **data): - super().__init__(**data) - - # Initialize the vectorizer and Weaviate client - self._embedder = Doc2VecEmbedding(vector_size=self.vector_size) - # self._initialize_client() - - def connect(self, **kwargs): - """ - Initialize the Weaviate client. - """ - if self._client is None: - self._client = weaviate.connect_to_weaviate_cloud( - cluster_url=self.url, - auth_credentials=Auth.api_key(self.api_key), - headers=kwargs.get("headers", {}) - ) - - def disconnect(self) -> None: - """ - Disconnects from the Qdrant cloud vector store. - """ - if self.client is not None: - self.client = None - - def add_document(self, document: Document) -> None: - """ - Add a single document to the vector store. - - :param document: Document to add - """ - try: - collection = self._client.collections.get(self.collection_name) - - # Generate or use existing embedding - embedding = document.embedding or self._embedder.fit_transform([document.content])[0] - - data_object = { - "content": document.content, - "metadata": document.metadata, - } - - # Generate UUID for document - uuid = ( - str(ud.uuid5(self._namespace_uuid, document.id)) - if document.id - else generate_uuid5(data_object) - ) - - collection.data.insert( - properties=data_object, - vector=embedding.value, - uuid=uuid, - ) - - print(f"Document '{document.id}' added to Weaviate.") - except Exception as e: - print(f"Error adding document '{document.id}': {e}") - raise - - def add_documents(self, documents: List[Document]) -> None: - """ - Add multiple documents to the vector store. - - :param documents: List of documents to add - """ - try: - for document in documents: - self.add_document(document) - - print(f"{len(documents)} documents added to Weaviate.") - except Exception as e: - print(f"Error adding documents: {e}") - raise - - def get_document(self, id: str) -> Union[Document, None]: - """ - Retrieve a document by its ID. - - :param id: Document ID - :return: Document object or None if not found - """ - try: - collection = self._client.collections.get(self.collection_name) - - result = collection.query.fetch_object_by_id(ud.uuid5(self._namespace_uuid, id)) - - if result: - return Document( - id=id, - content=result.properties["content"], - metadata=result.properties["metadata"], - ) - return None - except Exception as e: - print(f"Error retrieving document '{id}': {e}") - return None - - def get_all_documents(self) -> List[Document]: - """ - Retrieve all documents from the vector store. - - :return: List of Document objects - """ - try: - collection = self._client.collections.get(self.collection_name) - # return collection - documents = [ - Document( - content=item.properties["content"], - metadata=item.properties["metadata"], - embedding=Vector(value=list(item.vector.values())[0]), - ) - for item in collection.iterator(include_vector=True) - ] - return documents - except Exception as e: - print(f"Error retrieving all documents: {e}") - return [] - - def delete_document(self, id: str) -> None: - """ - Delete a document by its ID. 
- - :param id: Document ID - """ - try: - collection = self._client.collections.get(self.collection_name) - collection.data.delete_by_id(ud.uuid5(self._namespace_uuid, id)) - print(f"Document '{id}' has been deleted from Weaviate.") - except Exception as e: - print(f"Error deleting document '{id}': {e}") - raise - - def update_document(self, id: str, document: Document) -> None: - """ - Update an existing document. - - :param id: Document ID - :param updated_document: Document object with updated data - """ - self.delete_document(id) - self.add_document(document) - - def retrieve(self, query: str, top_k: int = 5) -> List[Document]: - """ - Retrieve the top_k most relevant documents based on the given query. - - :param query: Query string - :param top_k: Number of top similar documents to retrieve - :return: List of Document objects - """ - try: - collection = self._client.collections.get(self.collection_name) - query_vector = self._embedder.infer_vector(query) - response = collection.query.near_vector( - near_vector=query_vector.value, - limit=top_k, - return_metadata=MetadataQuery(distance=True), - ) - - documents = [ - Document( - # id=res.id, - content=res.properties["content"], - metadata=res.properties["metadata"], - ) - for res in response.objects - ] - return documents - except Exception as e: - print(f"Error retrieving documents for query '{query}': {e}") - return [] - - def close(self): - """ - Close the connection to the Weaviate server. - """ - if self._client: - self._client.close() - - def model_dump_json(self, *args, **kwargs) -> str: - # Call the disconnect method before serialization - self.disconnect() - - # Now proceed with the usual JSON serialization - return super().model_dump_json(*args, **kwargs) - - - def __del__(self): +from typing import List, Union, Literal, Optional +from pydantic import BaseModel, PrivateAttr +import uuid as ud +import weaviate +from weaviate.classes.init import Auth +from weaviate.util import generate_uuid5 +from weaviate.classes.query import MetadataQuery + +from swarmauri.documents.concrete.Document import Document +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri.vectors.concrete.Vector import Vector + +from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase +from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin +from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin +from swarmauri.vector_stores.base.VectorStoreCloudMixin import VectorStoreCloudMixin + + +class CloudWeaviateVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase, VectorStoreCloudMixin): + type: Literal["CloudWeaviateVectorStore"] = "CloudWeaviateVectorStore" + + + # Private attributes + _client: Optional[weaviate.Client] = PrivateAttr(default=None) + _embedder: Doc2VecEmbedding = PrivateAttr(default=None) + _namespace_uuid: ud.UUID = PrivateAttr(default_factory=ud.uuid4) + + def __init__(self, **data): + super().__init__(**data) + + # Initialize the vectorizer and Weaviate client + self._embedder = Doc2VecEmbedding(vector_size=self.vector_size) + # self._initialize_client() + + def connect(self, **kwargs): + """ + Initialize the Weaviate client. 
+ """ + if self._client is None: + self._client = weaviate.connect_to_weaviate_cloud( + cluster_url=self.url, + auth_credentials=Auth.api_key(self.api_key), + headers=kwargs.get("headers", {}) + ) + + def disconnect(self) -> None: + """ + Disconnects from the Qdrant cloud vector store. + """ + if self.client is not None: + self.client = None + + def add_document(self, document: Document) -> None: + """ + Add a single document to the vector store. + + :param document: Document to add + """ + try: + collection = self._client.collections.get(self.collection_name) + + # Generate or use existing embedding + embedding = document.embedding or self._embedder.fit_transform([document.content])[0] + + data_object = { + "content": document.content, + "metadata": document.metadata, + } + + # Generate UUID for document + uuid = ( + str(ud.uuid5(self._namespace_uuid, document.id)) + if document.id + else generate_uuid5(data_object) + ) + + collection.data.insert( + properties=data_object, + vector=embedding.value, + uuid=uuid, + ) + + print(f"Document '{document.id}' added to Weaviate.") + except Exception as e: + print(f"Error adding document '{document.id}': {e}") + raise + + def add_documents(self, documents: List[Document]) -> None: + """ + Add multiple documents to the vector store. + + :param documents: List of documents to add + """ + try: + for document in documents: + self.add_document(document) + + print(f"{len(documents)} documents added to Weaviate.") + except Exception as e: + print(f"Error adding documents: {e}") + raise + + def get_document(self, id: str) -> Union[Document, None]: + """ + Retrieve a document by its ID. + + :param id: Document ID + :return: Document object or None if not found + """ + try: + collection = self._client.collections.get(self.collection_name) + + result = collection.query.fetch_object_by_id(ud.uuid5(self._namespace_uuid, id)) + + if result: + return Document( + id=id, + content=result.properties["content"], + metadata=result.properties["metadata"], + ) + return None + except Exception as e: + print(f"Error retrieving document '{id}': {e}") + return None + + def get_all_documents(self) -> List[Document]: + """ + Retrieve all documents from the vector store. + + :return: List of Document objects + """ + try: + collection = self._client.collections.get(self.collection_name) + # return collection + documents = [ + Document( + content=item.properties["content"], + metadata=item.properties["metadata"], + embedding=Vector(value=list(item.vector.values())[0]), + ) + for item in collection.iterator(include_vector=True) + ] + return documents + except Exception as e: + print(f"Error retrieving all documents: {e}") + return [] + + def delete_document(self, id: str) -> None: + """ + Delete a document by its ID. + + :param id: Document ID + """ + try: + collection = self._client.collections.get(self.collection_name) + collection.data.delete_by_id(ud.uuid5(self._namespace_uuid, id)) + print(f"Document '{id}' has been deleted from Weaviate.") + except Exception as e: + print(f"Error deleting document '{id}': {e}") + raise + + def update_document(self, id: str, document: Document) -> None: + """ + Update an existing document. + + :param id: Document ID + :param updated_document: Document object with updated data + """ + self.delete_document(id) + self.add_document(document) + + def retrieve(self, query: str, top_k: int = 5) -> List[Document]: + """ + Retrieve the top_k most relevant documents based on the given query. 
+ + :param query: Query string + :param top_k: Number of top similar documents to retrieve + :return: List of Document objects + """ + try: + collection = self._client.collections.get(self.collection_name) + query_vector = self._embedder.infer_vector(query) + response = collection.query.near_vector( + near_vector=query_vector.value, + limit=top_k, + return_metadata=MetadataQuery(distance=True), + ) + + documents = [ + Document( + # id=res.id, + content=res.properties["content"], + metadata=res.properties["metadata"], + ) + for res in response.objects + ] + return documents + except Exception as e: + print(f"Error retrieving documents for query '{query}': {e}") + return [] + + def close(self): + """ + Close the connection to the Weaviate server. + """ + if self._client: + self._client.close() + + def model_dump_json(self, *args, **kwargs) -> str: + # Call the disconnect method before serialization + self.disconnect() + + # Now proceed with the usual JSON serialization + return super().model_dump_json(*args, **kwargs) + + + def __del__(self): self.close() \ No newline at end of file diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/Doc2VecVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/Doc2VecVectorStore.py similarity index 96% rename from pkgs/swarmauri/swarmauri/vector_stores/concrete/Doc2VecVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/Doc2VecVectorStore.py index cc4bace96..8aca2ee07 100644 --- a/pkgs/swarmauri/swarmauri/vector_stores/concrete/Doc2VecVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/Doc2VecVectorStore.py @@ -1,8 +1,7 @@ from typing import List, Union, Literal -from pydantic import PrivateAttr from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import ( diff --git a/pkgs/community/swarmauri_community/vector_stores/DuckDBVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/DuckDBVectorStore.py similarity index 99% rename from pkgs/community/swarmauri_community/vector_stores/DuckDBVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/DuckDBVectorStore.py index 03d3baf7a..c415c5b6a 100644 --- a/pkgs/community/swarmauri_community/vector_stores/DuckDBVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/DuckDBVectorStore.py @@ -7,7 +7,7 @@ from swarmauri.vectors.concrete.Vector import Vector from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/MlmVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/MlmVectorStore.py similarity index 96% rename from pkgs/swarmauri/swarmauri/vector_stores/concrete/MlmVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/MlmVectorStore.py index ea5902602..1fd98eed9 100644 --- 
a/pkgs/swarmauri/swarmauri/vector_stores/concrete/MlmVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/MlmVectorStore.py @@ -1,11 +1,12 @@ from typing import List, Union, Literal from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.MlmEmbedding import MlmEmbedding +from swarmauri_community.embeddings.concrete.MlmEmbedding import MlmEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin -from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin +from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin + class MlmVectorStore(VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase): type: Literal['MlmVectorStore'] = 'MlmVectorStore' diff --git a/pkgs/community/swarmauri_community/vector_stores/Neo4jVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/Neo4jVectorStore.py similarity index 99% rename from pkgs/community/swarmauri_community/vector_stores/Neo4jVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/Neo4jVectorStore.py index 75f162283..a67b1da37 100644 --- a/pkgs/community/swarmauri_community/vector_stores/Neo4jVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/Neo4jVectorStore.py @@ -1,5 +1,5 @@ from typing import List, Union, Literal, Optional -from pydantic import BaseModel, PrivateAttr, field_validator +from pydantic import BaseModel, PrivateAttr from neo4j import GraphDatabase import json diff --git a/pkgs/community/swarmauri_community/vector_stores/PersistentChromaDBVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/PersistentChromaDBVectorStore.py similarity index 98% rename from pkgs/community/swarmauri_community/vector_stores/PersistentChromaDBVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/PersistentChromaDBVectorStore.py index 83e16a7ae..413a89c25 100644 --- a/pkgs/community/swarmauri_community/vector_stores/PersistentChromaDBVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/PersistentChromaDBVectorStore.py @@ -5,7 +5,7 @@ from typing import List, Union, Literal from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/community/swarmauri_community/vector_stores/PersistentQdrantVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/PersistentQdrantVectorStore.py similarity index 98% rename from pkgs/community/swarmauri_community/vector_stores/PersistentQdrantVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/PersistentQdrantVectorStore.py index 2aab53309..e22b7c802 100644 --- a/pkgs/community/swarmauri_community/vector_stores/PersistentQdrantVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/PersistentQdrantVectorStore.py @@ -9,7 +9,7 @@ ) from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding 
+from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/community/swarmauri_community/vector_stores/PineconeVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/PineconeVectorStore.py similarity index 99% rename from pkgs/community/swarmauri_community/vector_stores/PineconeVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/PineconeVectorStore.py index a10421a78..d5817aee2 100644 --- a/pkgs/community/swarmauri_community/vector_stores/PineconeVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/PineconeVectorStore.py @@ -5,7 +5,7 @@ from pinecone import ServerlessSpec from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding from swarmauri.distances.concrete.CosineDistance import CosineDistance from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase diff --git a/pkgs/community/swarmauri_community/vector_stores/RedisVectorStore.py b/pkgs/community/swarmauri_community/vector_stores/concrete/RedisVectorStore.py similarity index 96% rename from pkgs/community/swarmauri_community/vector_stores/RedisVectorStore.py rename to pkgs/community/swarmauri_community/vector_stores/concrete/RedisVectorStore.py index 2b77026ac..8264de90f 100644 --- a/pkgs/community/swarmauri_community/vector_stores/RedisVectorStore.py +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/RedisVectorStore.py @@ -1,16 +1,15 @@ import json -from typing import List, Union, Literal, Dict, Optional +from typing import List, Union, Literal, Optional from pydantic import PrivateAttr import numpy as np import redis -from redis.commands.search.field import VectorField, TextField, TagField +from redis.commands.search.field import VectorField, TextField from redis.commands.search.indexDefinition import IndexDefinition, IndexType -from redis.commands.search.query import Query from swarmauri.vectors.concrete.Vector import Vector from swarmauri.documents.concrete.Document import Document -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding # or your specific embedder +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding # or your specific embedder from swarmauri.vector_stores.base.VectorStoreBase import VectorStoreBase from swarmauri.vector_stores.base.VectorStoreRetrieveMixin import VectorStoreRetrieveMixin from swarmauri.vector_stores.base.VectorStoreSaveLoadMixin import VectorStoreSaveLoadMixin diff --git a/pkgs/community/swarmauri_community/vector_stores/concrete/__init__.py b/pkgs/community/swarmauri_community/vector_stores/concrete/__init__.py new file mode 100644 index 000000000..f920d22b4 --- /dev/null +++ b/pkgs/community/swarmauri_community/vector_stores/concrete/__init__.py @@ -0,0 +1,20 @@ +from swarmauri.utils._lazy_import import _lazy_import + +vector_store_files = [ + ("swarmauri_community.vector_stores.concrete.AnnoyVectorStore", "AnnoyVectorStore"), + ("swarmauri_community.vector_stores.concrete.CloudQdrantVectorStore", "CloudQdrantVectorStore"), + ("swarmauri_community.vector_stores.concrete.CloudWeaviateVectorStore", "CloudWeaviateVectorStore"), + 
("swarmauri_community.vector_stores.concrete.Doc2VecVectorStore", "Doc2VecVectorStore"), + ("swarmauri_community.vector_stores.concrete.DuckDBVectorStore", "DuckDBVectorStore"), + ("swarmauri_community.vector_stores.concrete.MlmVectorStore", "MlmVectorStore"), + ("swarmauri_community.vector_stores.concrete.Neo4jVectorStore", "Neo4jVectorStore"), + ("swarmauri_community.vector_stores.concrete.PersistentChromaDBVectorStore", "PersistentChromaDBVectorStore"), + ("swarmauri_community.vector_stores.concrete.PersistentQdrantVectorStore", "PersistentQdrantVectorStore"), + ("swarmauri_community.vector_stores.concrete.PineconeVectorStore", "PineconeVectorStore"), + ("swarmauri_community.vector_stores.concrete.RedisVectorStore", "RedisVectorStore"), +] + +for module_name, class_name in vector_store_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +__all__ = [class_name for _, class_name in vector_store_files] diff --git a/pkgs/swarmauri/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py b/pkgs/community/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py similarity index 87% rename from pkgs/swarmauri/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py rename to pkgs/community/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py index c0dbb3a37..7f3afc447 100644 --- a/pkgs/swarmauri/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py +++ b/pkgs/community/tests/unit/embeddings/Doc2VecEmbedding_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding +from swarmauri_community.embeddings.concrete.Doc2VecEmbedding import Doc2VecEmbedding @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/swarmauri/tests/unit/embeddings/MlmEmbedding_unit_test.py b/pkgs/community/tests/unit/embeddings/MlmEmbedding_unit_test.py similarity index 87% rename from pkgs/swarmauri/tests/unit/embeddings/MlmEmbedding_unit_test.py rename to pkgs/community/tests/unit/embeddings/MlmEmbedding_unit_test.py index 6962bb802..c015aeac3 100644 --- a/pkgs/swarmauri/tests/unit/embeddings/MlmEmbedding_unit_test.py +++ b/pkgs/community/tests/unit/embeddings/MlmEmbedding_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.embeddings.concrete.MlmEmbedding import MlmEmbedding +from swarmauri_community.embeddings.concrete.MlmEmbedding import MlmEmbedding @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/swarmauri/tests/unit/parsers/TextBlobNounParser_unit_test.py b/pkgs/community/tests/unit/parsers/TextBlobNounParser_unit_test.py similarity index 92% rename from pkgs/swarmauri/tests/unit/parsers/TextBlobNounParser_unit_test.py rename to pkgs/community/tests/unit/parsers/TextBlobNounParser_unit_test.py index 6aa6bec95..e5f8a550c 100644 --- a/pkgs/swarmauri/tests/unit/parsers/TextBlobNounParser_unit_test.py +++ b/pkgs/community/tests/unit/parsers/TextBlobNounParser_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.parsers.concrete.TextBlobNounParser import TextBlobNounParser as Parser +from swarmauri_community.parsers.concrete.TextBlobNounParser import TextBlobNounParser as Parser def setup_module(module): diff --git a/pkgs/swarmauri/tests/unit/parsers/TextBlobSentenceParser_unit_test.py b/pkgs/community/tests/unit/parsers/TextBlobSentenceParser_unit_test.py similarity index 85% rename from pkgs/swarmauri/tests/unit/parsers/TextBlobSentenceParser_unit_test.py rename to pkgs/community/tests/unit/parsers/TextBlobSentenceParser_unit_test.py index a84023b2f..36c347906 100644 --- 
a/pkgs/swarmauri/tests/unit/parsers/TextBlobSentenceParser_unit_test.py +++ b/pkgs/community/tests/unit/parsers/TextBlobSentenceParser_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.parsers.concrete.TextBlobSentenceParser import TextBlobSentenceParser as Parser +from swarmauri_community.parsers.concrete.TextBlobSentenceParser import TextBlobSentenceParser as Parser @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/community/tests/unit/state/ClipboardState_test.py b/pkgs/community/tests/unit/state/ClipboardState_test.py new file mode 100644 index 000000000..2b9d200d6 --- /dev/null +++ b/pkgs/community/tests/unit/state/ClipboardState_test.py @@ -0,0 +1,91 @@ +import pytest +import pyperclip +from swarmauri_community.state.concrete.ClipboardState import ClipboardState + + +@pytest.fixture +def clipboard_state(): + """ + Fixture to create a ClipboardState instance and clean up clipboard after tests. + """ + # Store original clipboard content + original_clipboard = pyperclip.paste() + + # Create ClipboardState + state = ClipboardState() + + # Yield the state for tests to use + yield state + + # Restore original clipboard content after tests + pyperclip.copy(original_clipboard) + + +@pytest.mark.unit +def test_ubc_resource(clipboard_state): + """ + Test the resource type of the ClipboardState. + """ + assert clipboard_state.resource == "State" + + +@pytest.mark.unit +def test_write_and_read(clipboard_state): + """ + Test writing data to clipboard and reading it back. + """ + test_data = {"key1": "value1", "key2": 42} + clipboard_state.write(test_data) + read_data = clipboard_state.read() + assert read_data == test_data + + +@pytest.mark.unit +def test_update(clipboard_state): + """ + Test updating existing clipboard data. + """ + # Initial write + initial_data = {"existing_key": "existing_value"} + clipboard_state.write(initial_data) + + # Update with new data + update_data = {"new_key": "new_value"} + clipboard_state.update(update_data) + + # Read and verify merged data + read_data = clipboard_state.read() + expected_data = {"existing_key": "existing_value", "new_key": "new_value"} + assert read_data == expected_data + + +@pytest.mark.unit +def test_reset(clipboard_state): + """ + Test resetting the clipboard state to an empty dictionary. + """ + # Write some data + clipboard_state.write({"some_key": "some_value"}) + + # Reset + clipboard_state.reset() + + # Verify empty state + assert clipboard_state.read() == {} + + +@pytest.mark.unit +def test_deep_copy(clipboard_state): + """ + Test creating a deep copy of the clipboard state. 
+ """ + # Write initial data + initial_data = {"key1": "value1", "key2": "value2"} + clipboard_state.write(initial_data) + + # Create deep copy + copied_state = clipboard_state.deep_copy() + + # Verify copied state + assert isinstance(copied_state, ClipboardState) + assert copied_state.read() == initial_data diff --git a/pkgs/swarmauri/tests/unit/tools/TextLength_unit_test.py b/pkgs/community/tests/unit/tools/TextLength_unit_test.py similarity index 96% rename from pkgs/swarmauri/tests/unit/tools/TextLength_unit_test.py rename to pkgs/community/tests/unit/tools/TextLength_unit_test.py index 7bc0c94a1..72d3ff6bc 100644 --- a/pkgs/swarmauri/tests/unit/tools/TextLength_unit_test.py +++ b/pkgs/community/tests/unit/tools/TextLength_unit_test.py @@ -1,5 +1,5 @@ import pytest -from swarmauri.tools.concrete import TextLengthTool as Tool +from swarmauri_community.tools.concrete import TextLengthTool as Tool @pytest.mark.unit def test_ubc_resource(): diff --git a/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py index 9f2ce8320..cee7afddd 100644 --- a/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/AnnoyVectorStore_test.py @@ -1,6 +1,6 @@ import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.AnnoyVectorStore import AnnoyVectorStore +from swarmauri_community.vector_stores.concrete.AnnoyVectorStore import AnnoyVectorStore # Fixture for creating an AnnoyVectorStore instance diff --git a/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py index 2b9028c72..26ee25841 100644 --- a/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/CloudQdrantVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.CloudQdrantVectorStore import ( +from swarmauri_community.vector_stores.concrete.CloudQdrantVectorStore import ( CloudQdrantVectorStore, ) diff --git a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py index 051d6aa6b..9bad53191 100644 --- a/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/CloudWeaviateVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.CloudWeaviateVectorStore import ( +from swarmauri_community.vector_stores.concrete.CloudWeaviateVectorStore import ( CloudWeaviateVectorStore, ) from dotenv import load_dotenv diff --git a/pkgs/swarmauri/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py b/pkgs/community/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py similarity index 95% rename from pkgs/swarmauri/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py rename to pkgs/community/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py index 7afcd7097..497e8a45f 100644 --- a/pkgs/swarmauri/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py +++ b/pkgs/community/tests/unit/vector_stores/Doc2VecVectorStore_unit_test.py @@ -1,6 +1,6 @@ import pytest from swarmauri.documents.concrete.Document import Document -from 
swarmauri.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore +from swarmauri_community.vector_stores.concrete.Doc2VecVectorStore import Doc2VecVectorStore @pytest.mark.unit diff --git a/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py b/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py index 28bd33080..0b247ccd8 100644 --- a/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py +++ b/pkgs/community/tests/unit/vector_stores/DuckDBVectorStore_unit_test.py @@ -2,7 +2,7 @@ import os import json from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.DuckDBVectorStore import DuckDBVectorStore +from swarmauri_community.vector_stores.concrete.DuckDBVectorStore import DuckDBVectorStore @pytest.fixture(params=[":memory:", "test_db.db"]) diff --git a/pkgs/swarmauri/tests/unit/vector_stores/MlmVectorStore_unit_test.py b/pkgs/community/tests/unit/vector_stores/MlmVectorStore_unit_test.py similarity index 90% rename from pkgs/swarmauri/tests/unit/vector_stores/MlmVectorStore_unit_test.py rename to pkgs/community/tests/unit/vector_stores/MlmVectorStore_unit_test.py index 06b0fa263..1c3dc9273 100644 --- a/pkgs/swarmauri/tests/unit/vector_stores/MlmVectorStore_unit_test.py +++ b/pkgs/community/tests/unit/vector_stores/MlmVectorStore_unit_test.py @@ -1,6 +1,6 @@ import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri.vector_stores.concrete.MlmVectorStore import MlmVectorStore +from swarmauri_community.vector_stores.concrete.MlmVectorStore import MlmVectorStore @pytest.mark.unit diff --git a/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py index 74de4851b..d5e4699f6 100644 --- a/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/Neo4jVectorStore_test.py @@ -2,7 +2,7 @@ import pytest from dotenv import load_dotenv from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.Neo4jVectorStore import Neo4jVectorStore +from swarmauri_community.vector_stores.concrete.Neo4jVectorStore import Neo4jVectorStore # Load environment variables load_dotenv() diff --git a/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py index b973277f1..66e0f0ed3 100644 --- a/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/PersistentChromadbVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.PersistentChromaDBVectorStore import ( +from swarmauri_community.vector_stores.concrete.PersistentChromaDBVectorStore import ( PersistentChromaDBVectorStore, ) diff --git a/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py index 2277a87a4..d58c4295f 100644 --- a/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/PersistentQdrantVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.PersistentQdrantVectorStore import ( +from 
swarmauri_community.vector_stores.concrete.PersistentQdrantVectorStore import ( PersistentQdrantVectorStore, ) diff --git a/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py index de2749df0..803a3f61c 100644 --- a/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/PineconeVectorStore_test.py @@ -1,7 +1,7 @@ import os import pytest from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.PineconeVectorStore import PineconeVectorStore +from swarmauri_community.vector_stores.concrete.PineconeVectorStore import PineconeVectorStore from dotenv import load_dotenv load_dotenv() diff --git a/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py b/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py index 17cbce8a2..80fc1a2d5 100644 --- a/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py +++ b/pkgs/community/tests/unit/vector_stores/RedisVectorStore_test.py @@ -1,8 +1,6 @@ import pytest -import numpy as np -from swarmauri.documents.concrete.Document import Document -from swarmauri_community.vector_stores.RedisVectorStore import RedisVectorStore from swarmauri.documents.concrete.Document import Document +from swarmauri_community.vector_stores.concrete.RedisVectorStore import RedisVectorStore from dotenv import load_dotenv from os import getenv diff --git a/pkgs/core/pyproject.toml b/pkgs/core/pyproject.toml index 4f5283c21..5558e0cd5 100644 --- a/pkgs/core/pyproject.toml +++ b/pkgs/core/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri-core" -version = "0.5.2" +version = "0.5.3.dev3" description = "This repository includes core interfaces for the Swarmauri framework." 
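A note on the version strings being introduced here: `0.5.3.dev3` and its siblings are PEP 440 developmental releases, which sort before the final release they precede. A minimal sketch of that ordering, assuming the third-party `packaging` library (which is not part of this diff):

```python
# Hedged sketch: PEP 440 ordering of the dev versions pinned in this PR.
# Assumes `pip install packaging`; the library is not added by this diff.
from packaging.version import Version

# Dev releases pre-date their final release...
assert Version("0.5.2") < Version("0.5.3.dev3") < Version("0.5.3.dev5") < Version("0.5.3")
# ...so a pin like "==0.5.3.dev5" selects a pre-release that is newer
# than 0.5.2 but older than the eventual 0.5.3.
```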
authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -25,6 +25,7 @@ flake8 = "^7.0" # Add flake8 as a development dependency pytest = "^8.0" # Ensure pytest is also added if you run tests pytest-asyncio = ">=0.24.0" pytest-xdist = "^3.6.1" +pytest-json-report = "^1.5.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/pkgs/core/swarmauri_core/ComponentBase.py b/pkgs/core/swarmauri_core/ComponentBase.py index 9f3e86e9f..cbf5b770d 100644 --- a/pkgs/core/swarmauri_core/ComponentBase.py +++ b/pkgs/core/swarmauri_core/ComponentBase.py @@ -1,12 +1,13 @@ +import json from typing import ( + Any, + Dict, Optional, List, Literal, TypeVar, Type, Union, - Annotated, - Generic, ClassVar, Set, get_args, @@ -16,11 +17,14 @@ from enum import Enum import inspect import hashlib -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator import logging from swarmauri_core.typing import SubclassUnion +T = TypeVar("T", bound="ComponentBase") + + class ResourceTypes(Enum): UNIVERSAL_BASE = "ComponentBase" AGENT = "Agent" @@ -34,6 +38,7 @@ class ResourceTypes(Enum): DOCUMENT = "Document" EMBEDDING = "Embedding" EXCEPTION = "Exception" + IMAGE_GEN = "ImageGen" LLM = "LLM" MESSAGE = "Message" MEASUREMENT = "Measurement" @@ -51,12 +56,17 @@ class ResourceTypes(Enum): VECTOR_STORE = "VectorStore" VECTOR = "Vector" VCM = "VCM" - + DATA_CONNECTOR = "DataConnector" + TRANSPORT = "Transport" + FACTORY = "Factory" + PIPELINE = "Pipeline" + SERVICE_REGISTRY = "ServiceRegistry" + CONTROL_PANEL = "ControlPanel" + TASK_MGT_STRATEGY = "TaskMgtStrategy" def generate_id() -> str: return str(uuid4()) - class ComponentBase(BaseModel): name: Optional[str] = None id: str = Field(default_factory=generate_id) @@ -87,9 +97,6 @@ def __swm_register_subclass__(cls, subclass) -> None: f"Subclass {subclass.__name__} does not have a type annotation" ) - # [subclass.__swm_reset_class__() for subclass in cls.__swm_subclasses__ - # if hasattr(subclass, '__swm_reset_class__')] - @classmethod def __swm_reset_class__(cls): logging.debug("__swm_reset_class__ executed\n") @@ -122,9 +129,6 @@ def __swm_reset_class__(cls): cls.__annotations__[each] = sc cls.__fields__[each].annotation = sc - # This is not necessary as the model_rebuild address forward_refs - # https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_post_init - # cls.update_forward_refs() cls.model_rebuild(force=True) @field_validator("type") @@ -178,3 +182,62 @@ def swm_path(self): @property def swm_isremote(self): return bool(self.host) + + @classmethod + def model_validate_json( + cls: Type[T], json_payload: Union[str, Dict[str, Any]], strict: bool = False + ) -> T: + # Ensure we're working with a dictionary + if isinstance(json_payload, str): + try: + payload_dict = json.loads(json_payload) + except json.JSONDecodeError: + raise ValueError("Invalid JSON payload") + else: + payload_dict = json_payload + + # Try to determine the specific component type + component_type = payload_dict.get("type", "ComponentBase") + + # Attempt to find the correct subclass + target_cls = cls.get_subclass_by_type(component_type) + + # Fallback logic + if target_cls is None: + if strict: + raise ValueError(f"Cannot resolve component type: {component_type}") + target_cls = cls + logging.warning( + f"Falling back to base ComponentBase for type: {component_type}" + ) + + # Validate using the determined class + try: + return target_cls.model_validate(payload_dict) + except ValidationError as e: + 
logging.error(f"Validation failed for {component_type}: {e}") + raise + + @classmethod + def get_subclass_by_type(cls, type_name: str) -> Optional[Type["ComponentBase"]]: + # First, check for exact match in registered subclasses + for subclass in cls.__swm_subclasses__: + if ( + subclass.__name__ == type_name + or getattr(subclass, "type", None) == type_name + ): + return subclass + + # If no exact match, try case-insensitive search + for subclass in cls.__swm_subclasses__: + if ( + subclass.__name__.lower() == type_name.lower() + or str(getattr(subclass, "type", "")).lower() == type_name.lower() + ): + return subclass + + return None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ComponentBase": + return cls.model_validate(data) diff --git a/pkgs/core/swarmauri_core/agent_factories/IAgentFactory.py b/pkgs/core/swarmauri_core/agent_factories/IAgentFactory.py deleted file mode 100644 index 8dc24b3d8..000000000 --- a/pkgs/core/swarmauri_core/agent_factories/IAgentFactory.py +++ /dev/null @@ -1,79 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Type, Any -from datetime import datetime - -class IAgentFactory(ABC): - """ - Interface for Agent Factories, extended to include properties like ID, name, type, - creation date, and last modification date. - """ - - @abstractmethod - def create_agent(self, agent_type: str, **kwargs) -> Any: - pass - - @abstractmethod - def register_agent(self, agent_type: str, constructor: Type[Any]) -> None: - pass - - # Abstract properties and setters - @property - @abstractmethod - def id(self) -> str: - """Unique identifier for the factory instance.""" - pass - - @id.setter - @abstractmethod - def id(self, value: str) -> None: - pass - - @property - @abstractmethod - def name(self) -> str: - """Name of the factory.""" - pass - - @name.setter - @abstractmethod - def name(self, value: str) -> None: - pass - - @property - @abstractmethod - def type(self) -> str: - """Type of agents this factory produces.""" - pass - - @type.setter - @abstractmethod - def type(self, value: str) -> None: - pass - - @property - @abstractmethod - def date_created(self) -> datetime: - """The creation date of the factory instance.""" - pass - - @property - @abstractmethod - def last_modified(self) -> datetime: - """Date when the factory was last modified.""" - pass - - @last_modified.setter - @abstractmethod - def last_modified(self, value: datetime) -> None: - pass - - def __hash__(self): - """ - The __hash__ method allows objects of this class to be used in sets and as dictionary keys. - __hash__ should return an integer and be defined based on immutable properties. - This is generally implemented directly in concrete classes rather than in the interface, - but it's declared here to indicate that implementing classes must provide it. - """ - pass - - \ No newline at end of file diff --git a/pkgs/core/swarmauri_core/control_panels/IControlPanel.py b/pkgs/core/swarmauri_core/control_panels/IControlPanel.py new file mode 100644 index 000000000..d1810b023 --- /dev/null +++ b/pkgs/core/swarmauri_core/control_panels/IControlPanel.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import Any, List + + +class IControlPlane(ABC): + """ + Abstract base class for ControlPlane. + """ + + @abstractmethod + def create_agent(self, name: str, role: str) -> Any: + """ + Create an agent with the given name and role. + """ + pass + + @abstractmethod + def remove_agent(self, name: str) -> None: + """ + Remove the agent with the specified name. 
+ """ + pass + + @abstractmethod + def list_active_agents(self) -> List[str]: + """ + List all active agent names. + """ + pass + + @abstractmethod + def submit_tasks(self, tasks: List[Any]) -> None: + """ + Submit one or more tasks to the task management strategy for processing. + """ + pass + + @abstractmethod + def process_tasks(self) -> None: + """ + Process and assign tasks from the queue, then transport them to their assigned services. + """ + pass + + @abstractmethod + def distribute_tasks(self, task: Any) -> None: + """ + Distribute tasks using the task strategy. + """ + pass + + @abstractmethod + def orchestrate_agents(self, task: Any) -> None: + """ + Orchestrate agents for task distribution. + """ + pass diff --git a/pkgs/core/swarmauri_core/control_panels/__init__.py b/pkgs/core/swarmauri_core/control_panels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/dataconnectors/IDataConnector.py b/pkgs/core/swarmauri_core/dataconnectors/IDataConnector.py new file mode 100644 index 000000000..b0c7826c4 --- /dev/null +++ b/pkgs/core/swarmauri_core/dataconnectors/IDataConnector.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod + + +class IDataConnector(ABC): + """ + Abstract base class for data connectors. + Defines the interface for all concrete data connector implementations. + """ + + @abstractmethod + def authenticate(self, **kwargs): + """ + Authenticate with the data source. + This method should handle any required authentication process. + + :param kwargs: Authentication parameters such as API keys, tokens, etc. + """ + pass + + @abstractmethod + def fetch_data(self, query: str, **kwargs): + """ + Fetch data from the data source based on a query. + + :param query: Query string or parameters to fetch specific data. + :param kwargs: Additional parameters for fetching data. + :return: Data fetched from the source. + """ + pass + + @abstractmethod + def insert_data(self, data, **kwargs): + """ + Insert data into the data source. + + :param data: Data to be inserted. + :param kwargs: Additional parameters for inserting data. + """ + pass + + @abstractmethod + def update_data(self, identifier, data, **kwargs): + """ + Update existing data in the data source. + + :param identifier: Unique identifier of the data to update. + :param data: Updated data. + :param kwargs: Additional parameters for updating data. + """ + pass + + @abstractmethod + def delete_data(self, identifier, **kwargs): + """ + Delete data from the data source. + + :param identifier: Unique identifier of the data to delete. + :param kwargs: Additional parameters for deleting data. + """ + pass + + @abstractmethod + def test_connection(self, **kwargs): + """ + Test the connection to the data source. + + :param kwargs: Connection parameters. + :return: Boolean indicating whether the connection is successful. + """ + pass diff --git a/pkgs/core/swarmauri_core/dataconnectors/__init__.py b/pkgs/core/swarmauri_core/dataconnectors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/factories/IFactory.py b/pkgs/core/swarmauri_core/factories/IFactory.py new file mode 100644 index 000000000..a30fc7d15 --- /dev/null +++ b/pkgs/core/swarmauri_core/factories/IFactory.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable + + +class IFactory(ABC): + """ + Interface defining core methods for factories. 
+ """ + + @abstractmethod + def create(self, type: str, *args: Any, **kwargs: Any) -> Any: + """Create and return an instance.""" + pass + + @abstractmethod + def register(self, type: str, resource_class: Callable) -> None: + """Register a class with the factory.""" + pass diff --git a/pkgs/core/swarmauri_core/factories/__init__.py b/pkgs/core/swarmauri_core/factories/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/image_gens/IGenImage.py b/pkgs/core/swarmauri_core/image_gens/IGenImage.py new file mode 100644 index 000000000..79cebf615 --- /dev/null +++ b/pkgs/core/swarmauri_core/image_gens/IGenImage.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod + + +class IGenImage(ABC): + """ + Interface focusing on the basic properties and settings essential for defining image generating models. + """ + + @abstractmethod + def generate_image(self, *args, **kwargs) -> any: + """ + Generate images based on the input data provided to the model. + """ + pass + + @abstractmethod + async def agenerate_image(self, *args, **kwargs) -> any: + """ + Generate images based on the input data provided to the model. + """ + pass + + @abstractmethod + def batch_generate(self, *args, **kwargs) -> any: + """ + Generate images based on the input data provided to the model. + """ + pass + + @abstractmethod + async def abatch_generate(self, *args, **kwargs) -> any: + """ + Generate images based on the input data provided to the model. + """ + pass diff --git a/pkgs/core/swarmauri_core/image_gens/__init__.py b/pkgs/core/swarmauri_core/image_gens/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/pipelines/IPipeline.py b/pkgs/core/swarmauri_core/pipelines/IPipeline.py new file mode 100644 index 000000000..9e941fdd9 --- /dev/null +++ b/pkgs/core/swarmauri_core/pipelines/IPipeline.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, List +from enum import Enum + + +class PipelineStatus(Enum): + """ + Enum representing the status of a pipeline execution. + """ + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + STOPPED = "stopped" + + +class IPipeline(ABC): + """ + Interface defining core methods for pipeline execution and management. + """ + + @abstractmethod + def add_task(self, task: Callable, *args: Any, **kwargs: Any) -> None: + """ + Add a task to the pipeline. + + :param task: Callable task to be executed + :param args: Positional arguments for the task + :param kwargs: Keyword arguments for the task + """ + pass + + @abstractmethod + def execute(self, *args: Any, **kwargs: Any) -> List[Any]: + """ + Execute the entire pipeline. + + :return: List of results from pipeline execution + """ + pass + + @abstractmethod + def get_status(self) -> PipelineStatus: + """ + Get the current status of the pipeline. + + :return: Current pipeline status + """ + pass + + @abstractmethod + def reset(self) -> None: + """ + Reset the pipeline to its initial state. 
+ """ + pass diff --git a/pkgs/core/swarmauri_core/pipelines/__init__.py b/pkgs/core/swarmauri_core/pipelines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/service_registries/IServiceRegistry.py b/pkgs/core/swarmauri_core/service_registries/IServiceRegistry.py new file mode 100644 index 000000000..911dc7b19 --- /dev/null +++ b/pkgs/core/swarmauri_core/service_registries/IServiceRegistry.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional + + +class IServiceRegistry(ABC): + """ + Abstract base class for service registries. + """ + + @abstractmethod + def register_service(self, name: str, details: Dict[str, Any]) -> None: + """ + Register a new service with the given name and details. + """ + pass + + @abstractmethod + def get_service(self, name: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a service by its name. + """ + pass + + @abstractmethod + def get_services_by_roles(self, roles: List[str]) -> List[str]: + """ + Get services filtered by their roles. + """ + pass + + @abstractmethod + def unregister_service(self, name: str) -> None: + """ + unregister the service with the given name. + """ + pass + + @abstractmethod + def update_service(self, name: str, details: Dict[str, Any]) -> None: + """ + Update the details of the service with the given name. + """ + pass diff --git a/pkgs/core/swarmauri_core/service_registries/__init__.py b/pkgs/core/swarmauri_core/service_registries/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/state/IState.py b/pkgs/core/swarmauri_core/state/IState.py new file mode 100644 index 000000000..02c571bea --- /dev/null +++ b/pkgs/core/swarmauri_core/state/IState.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any + + +class IState(ABC): + @abstractmethod + def read(self) -> Dict[str, Any]: + """ + Reads and returns the current state as a dictionary. + """ + + @abstractmethod + def write(self, data: Dict[str, Any]) -> None: + """ + Replaces the current state with the given data. + """ + + @abstractmethod + def update(self, data: Dict[str, Any]) -> None: + """ + Updates the state with the given data. + """ + + @abstractmethod + def reset(self) -> None: + """ + Resets the state to its initial state. + """ diff --git a/pkgs/core/swarmauri_core/swarms/ISwarm.py b/pkgs/core/swarmauri_core/swarms/ISwarm.py index 7d3ecb8d8..7f3e0bb15 100644 --- a/pkgs/core/swarmauri_core/swarms/ISwarm.py +++ b/pkgs/core/swarmauri_core/swarms/ISwarm.py @@ -1,67 +1,32 @@ from abc import ABC, abstractmethod -from typing import Any, List, Dict -from datetime import datetime -from swarmauri_core.agents.IAgent import IAgent -from swarmauri_core.chains.ICallableChain import ICallableChain +from typing import Any, Dict, List, Optional, Union -class ISwarm(ABC): - """ - Interface for a Swarm, representing a collective of agents capable of performing tasks, executing callable chains, and adaptable configurations. 
- """ - - # Abstract properties and setters - @property - @abstractmethod - def id(self) -> str: - """Unique identifier for the factory instance.""" - pass - - @id.setter - @abstractmethod - def id(self, value: str) -> None: - pass - - @property - @abstractmethod - def name(self) -> str: - pass - @name.setter - @abstractmethod - def name(self, value: str) -> None: - pass +class ISwarm(ABC): + """Abstract base class for swarm implementations""" - @property @abstractmethod - def type(self) -> str: + async def exec( + self, + input_data: Union[str, List[str]], + **kwargs: Dict[str, Any], + ) -> Any: + """Execute swarm tasks with given input""" pass - @type.setter @abstractmethod - def type(self, value: str) -> None: + def get_swarm_status(self) -> Dict[int, Any]: + """Get status of all agents in the swarm""" pass @property @abstractmethod - def date_created(self) -> datetime: + def agents(self) -> List[Any]: + """Get list of agents in the swarm""" pass @property @abstractmethod - def last_modified(self) -> datetime: + def queue_size(self) -> int: + """Get size of task queue""" pass - - @last_modified.setter - @abstractmethod - def last_modified(self, value: datetime) -> None: - pass - - def __hash__(self): - """ - The __hash__ method allows objects of this class to be used in sets and as dictionary keys. - __hash__ should return an integer and be defined based on immutable properties. - This is generally implemented directly in concrete classes rather than in the interface, - but it's declared here to indicate that implementing classes must provide it. - """ - pass - diff --git a/pkgs/core/swarmauri_core/swarms/ISwarmAgentRegistration.py b/pkgs/core/swarmauri_core/swarms/ISwarmAgentRegistration.py deleted file mode 100644 index a32dc235d..000000000 --- a/pkgs/core/swarmauri_core/swarms/ISwarmAgentRegistration.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Dict, Optional -from swarmauri_core.agents.IAgent import IAgent - -class ISwarmAgentRegistration(ABC): - """ - Interface for registering agents with the swarm, designed to support CRUD operations on IAgent instances. - """ - - @id.setter - @abstractmethod - def registry(self, value: str) -> None: - pass - - @property - @abstractmethod - def registry(self) -> List[IAgent]: - pass - - @abstractmethod - def register_agent(self, agent: IAgent) -> bool: - """ - Register a new agent with the swarm. - - Parameters: - agent (IAgent): An instance of IAgent representing the agent to register. - - Returns: - bool: True if the registration succeeded; False otherwise. - """ - pass - - @abstractmethod - def update_agent(self, agent_id: str, updated_agent: IAgent) -> bool: - """ - Update the details of an existing agent. This could include changing the agent's configuration, - task assignment, or any other mutable attribute. - - Parameters: - agent_id (str): The unique identifier for the agent. - updated_agent (IAgent): An updated IAgent instance to replace the existing one. - - Returns: - bool: True if the update was successful; False otherwise. - """ - pass - - @abstractmethod - def remove_agent(self, agent_id: str) -> bool: - """ - Remove an agent from the swarm based on its unique identifier. - - Parameters: - agent_id (str): The unique identifier for the agent to be removed. - - Returns: - bool: True if the removal was successful; False otherwise. - """ - pass - - @abstractmethod - def get_agent(self, agent_id: str) -> Optional[IAgent]: - """ - Retrieve an agent's instance from its unique identifier. 
- - Parameters: - agent_id (str): The unique identifier for the agent of interest. - - Returns: - Optional[IAgent]: The IAgent instance if found; None otherwise. - """ - pass - diff --git a/pkgs/core/swarmauri_core/swarms/ISwarmChainCRUD.py b/pkgs/core/swarmauri_core/swarms/ISwarmChainCRUD.py deleted file mode 100644 index dcb0cf191..000000000 --- a/pkgs/core/swarmauri_core/swarms/ISwarmChainCRUD.py +++ /dev/null @@ -1,62 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Dict, Any - -class ISwarmChainCRUD(ABC): - """ - Interface to provide CRUD operations for ICallableChain within swarms. - """ - - @abstractmethod - def create_chain(self, chain_id: str, chain_definition: Dict[str, Any]) -> None: - """ - Creates a callable chain with the provided definition. - - Parameters: - - chain_id (str): A unique identifier for the callable chain. - - chain_definition (Dict[str, Any]): The definition of the callable chain including steps and their configurations. - """ - pass - - @abstractmethod - def read_chain(self, chain_id: str) -> Dict[str, Any]: - """ - Retrieves the definition of a callable chain by its identifier. - - Parameters: - - chain_id (str): The unique identifier of the callable chain to be retrieved. - - Returns: - - Dict[str, Any]: The definition of the callable chain. - """ - pass - - @abstractmethod - def update_chain(self, chain_id: str, new_definition: Dict[str, Any]) -> None: - """ - Updates an existing callable chain with a new definition. - - Parameters: - - chain_id (str): The unique identifier of the callable chain to be updated. - - new_definition (Dict[str, Any]): The new definition of the callable chain including updated steps and configurations. - """ - pass - - @abstractmethod - def delete_chain(self, chain_id: str) -> None: - """ - Removes a callable chain from the swarm. - - Parameters: - - chain_id (str): The unique identifier of the callable chain to be removed. - """ - pass - - @abstractmethod - def list_chains(self) -> List[Dict[str, Any]]: - """ - Lists all callable chains currently managed by the swarm. - - Returns: - - List[Dict[str, Any]]: A list of callable chain definitions. - """ - pass \ No newline at end of file diff --git a/pkgs/core/swarmauri_core/swarms/ISwarmComponent.py b/pkgs/core/swarmauri_core/swarms/ISwarmComponent.py deleted file mode 100644 index aee78b027..000000000 --- a/pkgs/core/swarmauri_core/swarms/ISwarmComponent.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod - -class ISwarmComponent(ABC): - """ - Interface for defining a general component within a swarm system. - """ - - @abstractmethod - def __init__(self, key: str, name: str): - """ - Initializes a swarm component with a unique key and name. - """ - pass \ No newline at end of file diff --git a/pkgs/core/swarmauri_core/swarms/ISwarmConfigurationExporter.py b/pkgs/core/swarmauri_core/swarms/ISwarmConfigurationExporter.py deleted file mode 100644 index 3b1672ca3..000000000 --- a/pkgs/core/swarmauri_core/swarms/ISwarmConfigurationExporter.py +++ /dev/null @@ -1,33 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict -class ISwarmConfigurationExporter(ABC): - - @abstractmethod - def to_dict(self) -> Dict: - """ - Serializes the swarm configuration to a dictionary. - - Returns: - Dict: The serialized configuration as a dictionary. - """ - pass - - @abstractmethod - def to_json(self) -> str: - """ - Serializes the swarm configuration to a JSON string. - - Returns: - str: The serialized configuration as a JSON string. 
- """ - pass - - @abstractmethod - def to_pickle(self) -> bytes: - """ - Serializes the swarm configuration to a Pickle byte stream. - - Returns: - bytes: The serialized configuration as a Pickle byte stream. - """ - pass \ No newline at end of file diff --git a/pkgs/core/swarmauri_core/swarms/ISwarmFactory.py b/pkgs/core/swarmauri_core/swarms/ISwarmFactory.py deleted file mode 100644 index d26416114..000000000 --- a/pkgs/core/swarmauri_core/swarms/ISwarmFactory.py +++ /dev/null @@ -1,120 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union -from swarmauri_core.swarms.ISwarm import ISwarm -from swarmauri_core.chains.ICallableChain import ICallableChain -from swarmauri_core.agents.IAgent import IAgent - -class Step(NamedTuple): - description: str - callable: Callable # Reference to the function to execute - args: Optional[List[Any]] = None - kwargs: Optional[Dict[str, Any]] = None - -class CallableChainItem(NamedTuple): - key: str # Unique identifier for the item within the chain - execution_context: Dict[str, Any] # Execution context and metadata - steps: List[Step] - -class AgentDefinition(NamedTuple): - type: str - configuration: Dict[str, Any] - capabilities: List[str] - dependencies: List[str] - execution_context: Dict[str, Any] - -class FunctionParameter(NamedTuple): - name: str - type: Type - default: Optional[Any] = None - required: bool = True - -class FunctionDefinition(NamedTuple): - identifier: str - parameters: List[FunctionParameter] - return_type: Type - execution_context: Dict[str, Any] - callable_source: Callable - -class ISwarmFactory(ABC): - - @abstractmethod - def create_swarm(self, *args, **kwargs) -> ISwarm: - """ - Creates and returns a new swarm instance configured with the provided arguments. - """ - pass - - @abstractmethod - def create_agent(self, agent_definition: AgentDefinition) -> IAgent: - """ - Creates a new agent based on the provided enhanced agent definition. - - Args: - agent_definition: An instance of AgentDefinition detailing the agent's setup. - - Returns: - An instance or identifier of the newly created agent. - """ - pass - - @abstractmethod - def create_callable_chain(self, chain_definition: List[CallableChainItem]) -> ICallableChain: - """ - Creates a new callable chain based on the provided definition. - - Args: - chain_definition: Details required to build the chain, such as sequence of functions and arguments. - - Returns: - ICallableChain: The constructed callable chain instance. - """ - pass - - @abstractmethod - def register_function(self, function_definition: FunctionDefinition) -> None: - """ - Registers a function within the factory ecosystem, making it available for callable chains and agents. - - Args: - function_definition: An instance of FunctionDefinition detailing the function's specification. - """ - pass - - @abstractmethod - def export_callable_chains(self, format_type: str = 'json') -> Union[dict, str, bytes]: - """ - Exports configurations of all callable chains in the specified format. - Supported formats: 'json', 'pickle'. - - Args: - format_type (str): The format for exporting the configurations. - - Returns: - Union[dict, str, bytes]: The callable chain configurations in the specified format. - """ - pass - - @abstractmethod - def load_callable_chains(self, chains_data, format_type: str = 'json'): - """ - Loads callable chain configurations from given data. 
- - Args: - chains_data (Union[dict, str, bytes]): Data containing callable chain configurations. - format_type (str): The format of the provided chains data. - """ - pass - - @abstractmethod - def export_configuration(self, format_type: str = 'json') -> Union[dict, str, bytes]: - """ - Exports the swarm's and agents' configurations in the specified format. - Supported formats: 'json', 'pickle'. Default is 'json'. - - Args: - format_type (str): The format for exporting the configurations. - - Returns: - Union[dict, str, bytes]: The configurations in the specified format. - """ - pass diff --git a/pkgs/core/swarmauri_core/task_mgt_strategies/ITaskMgtStrategy.py b/pkgs/core/swarmauri_core/task_mgt_strategies/ITaskMgtStrategy.py new file mode 100644 index 000000000..13be06917 --- /dev/null +++ b/pkgs/core/swarmauri_core/task_mgt_strategies/ITaskMgtStrategy.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict + + +class ITaskMgtStrategy(ABC): + """Abstract base class for TaskStrategy.""" + + @abstractmethod + def assign_task( + self, task: Dict[str, Any], agent_factory: Callable, service_registry: Callable + ) -> str: + """ + Abstract method to assign a task to a service. + """ + pass + + @abstractmethod + def add_task(self, task: Dict[str, Any]) -> None: + """ + Abstract method to add a task to the task queue. + """ + pass + + @abstractmethod + def remove_task(self, task_id: str) -> None: + """ + Abstract method to remove a task from the task queue. + """ + pass + + @abstractmethod + def get_task(self, task_id: str) -> Dict[str, Any]: + """ + Abstract method to get a task from the task queue. + """ + pass + + @abstractmethod + def process_tasks(self, task: Dict[str, Any]) -> None: + """ + Abstract method to process a task. + """ + pass diff --git a/pkgs/core/swarmauri_core/task_mgt_strategies/__init__.py b/pkgs/core/swarmauri_core/task_mgt_strategies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/core/swarmauri_core/transports/ITransport.py b/pkgs/core/swarmauri_core/transports/ITransport.py new file mode 100644 index 000000000..c9f811473 --- /dev/null +++ b/pkgs/core/swarmauri_core/transports/ITransport.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Any, List + + +class ITransport(ABC): + """ + Interface defining standard transportation methods for agent interactions + """ + + @abstractmethod + def send(self, sender: str, recipient: str, message: Any) -> None: + """ + Send a message to a specific recipient + """ + pass + + @abstractmethod + def broadcast(self, sender: str, message: Any) -> None: + """ + Broadcast a message to all agents + """ + pass + + @abstractmethod + def multicast(self, sender: str, recipients: List[str], message: Any) -> None: + """ + Send a message to multiple specific recipients + """ + pass diff --git a/pkgs/core/swarmauri_core/transports/__init__.py b/pkgs/core/swarmauri_core/transports/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/experimental/pyproject.toml b/pkgs/experimental/pyproject.toml index bea605df3..a9bef7a00 100644 --- a/pkgs/experimental/pyproject.toml +++ b/pkgs/experimental/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri-experimental" -version = "0.5.2" +version = "0.5.3.dev5" description = "This repository includes experimental components." 
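Before the packaging changes continue, a quick illustration of the `ITransport` contract introduced above, delivering messages into per-recipient in-memory inboxes (all names below are hypothetical):

```python
from collections import defaultdict
from typing import Any, Dict, List

from swarmauri_core.transports.ITransport import ITransport


class InMemoryTransport(ITransport):
    """Toy transport: every message lands in the recipient's inbox list."""

    def __init__(self, agent_names: List[str]) -> None:
        self._agents = list(agent_names)
        self.inboxes: Dict[str, List[Any]] = defaultdict(list)

    def send(self, sender: str, recipient: str, message: Any) -> None:
        self.inboxes[recipient].append((sender, message))

    def broadcast(self, sender: str, message: Any) -> None:
        self.multicast(sender, self._agents, message)

    def multicast(self, sender: str, recipients: List[str], message: Any) -> None:
        for recipient in recipients:
            self.send(sender, recipient, message)
```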
authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -14,8 +14,8 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.10,<4.0" -swarmauri = "==0.5.2" +python = ">=3.10,<3.13" +swarmauri = "==0.5.3.dev5" gensim = "*" neo4j = "*" numpy = "*" @@ -31,6 +31,8 @@ typing_extensions = "*" flake8 = "^7.0" # Add flake8 as a development dependency pytest = "^8.0" # Ensure pytest is also added if you run tests pytest-asyncio = ">=0.24.0" +pytest-xdist = "^3.6.1" +pytest-json-report = "^1.5.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/pkgs/swarmauri/pyproject.toml b/pkgs/swarmauri/pyproject.toml index 7ba365c89..18ea3f273 100644 --- a/pkgs/swarmauri/pyproject.toml +++ b/pkgs/swarmauri/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swarmauri" -version = "0.5.2" +version = "0.5.3.dev5" description = "This repository includes base classes, concrete generics, and concrete standard components within the Swarmauri framework." authors = ["Jacob Stewart "] license = "Apache-2.0" @@ -14,17 +14,26 @@ classifiers = [ ] [tool.poetry.dependencies] +# Python Version python = ">=3.10,<3.13" -swarmauri_core = "==0.5.2" + +# Swarmauri +swarmauri_core = { path = "../core" } + +# Dependencies toml = "^0.10.2" -httpx = "^0.27.2" +httpx = "^0.27.0" joblib = "^1.4.0" numpy = "*" pandas = "*" pydantic = "^2.9.2" -Pillow = ">=8.0,<11.0" typing_extensions = "*" -requests = "*" + +# We should remove and only rely on httpx +requests = "^2.32.3" + +# This should be set to optional also +Pillow = ">=8.0,<11.0" # Optional dependencies with versions specified aiofiles = { version = "24.1.0", optional = true } @@ -34,44 +43,51 @@ aiohttp = { version = "^3.10.10", optional = true } #fal-client = { version = ">=0.5.0", optional = true } #google-generativeai = { version = "^0.8.3", optional = true } #openai = { version = "^1.52.0", optional = true } -nltk = { version = "^3.9.1", optional = true } -textblob = { version = "^0.18.0", optional = true } +#nltk = { version = "^3.9.1", optional = true } +#textblob = { version = "^0.18.0", optional = true } yake = { version = "==0.4.8", optional = true } beautifulsoup4 = { version = "04.12.3", optional = true } -gensim = { version = "==4.3.3", optional = true } +#gensim = { version = "==4.3.3", optional = true } scipy = { version = ">=1.7.0,<1.14.0", optional = true } scikit-learn = { version = "^1.4.2", optional = true } -spacy = { version = ">=3.0.0,<=3.8.2", optional = true } -transformers = { version = "^4.45.0", optional = true } -torch = { version = "^2.5.0", optional = true } -keras = { version = ">=3.2.0", optional = true } -tf-keras = { version = ">=2.16.0", optional = true } +#spacy = { version = ">=3.0.0,<=3.8.2", optional = true } +#transformers = { version = "^4.45.0", optional = true } +#torch = { version = "^2.5.0", optional = true } +#keras = { version = ">=3.2.0", optional = true } +#tf-keras = { version = ">=2.16.0", optional = true } matplotlib = { version = ">=3.9.2", optional = true } [tool.poetry.extras] # Extras without versioning, grouped for specific use cases io = ["aiofiles", "aiohttp"] #llms = ["cohere", "mistralai", "fal-client", "google-generativeai", "openai"] -nlp = ["nltk", "textblob", "yake"] +nlp = [ + #"nltk", + #"textblob", + "yake" +] nlp_tools = ["beautifulsoup4"] -ml_toolkits = ["gensim", "scipy", "scikit-learn"] -spacy = ["spacy"] -transformers = ["transformers"] -torch = ["torch"] -tensorflow = ["keras", "tf-keras"] +#ml_toolkits = ["gensim", "scipy", "scikit-learn"] +ml_toolkits = ["scikit-learn"] +#spacy = 
["spacy"] +#transformers = ["transformers"] +#torch = ["torch"] +#tensorflow = ["keras", "tf-keras"] visualization = ["matplotlib"] # Full option to install all extras full = [ "aiofiles", "aiohttp", #"cohere", "mistralai", "fal-client", "google-generativeai", "openai", - "nltk", "textblob", "yake", + #"nltk", "textblob", + "yake", "beautifulsoup4", - "gensim", "scipy", "scikit-learn", - "spacy", - "transformers", - "torch", - "keras", "tf-keras", + "scikit-learn", + #"gensim", "scipy", "scikit-learn", + #"spacy", + #"transformers", + #"torch", + #"keras", "tf-keras", "matplotlib" ] @@ -81,9 +97,11 @@ pytest = "^8.0" pytest-asyncio = ">=0.24.0" pytest-timeout = "^2.3.1" pytest-xdist = "^3.6.1" +pytest-json-report = "^1.5.0" python-dotenv = "^1.0.0" jsonschema = "^4.18.5" -ipython = "8.28.0" +ipython = "^8.28.0" +requests = "^2.32.3" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py b/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py index 651d9d992..8b75d563e 100644 --- a/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/agent_factories/concrete/__init__.py @@ -1,8 +1,22 @@ -from swarmauri.agent_factories.concrete.agent_factory import AgentFactory -from swarmauri.agent_factories.concrete.conf_driven_agent_factory import ( - ConfDrivenAgentFactory, -) -from JsonAgentFactory import JsonAgentFactory -from swarmauri.agent_factories.concrete.ReflectionAgentFactory import ( - ReflectionAgentFactory, -) +from swarmauri.utils._lazy_import import _lazy_import + +# List of agent factory names (file names without the ".py" extension) and corresponding class names +agent_factory_files = [ + ("swarmauri.agent_factories.concrete.agent_factory", "AgentFactory"), + ( + "swarmauri.agent_factories.concrete.conf_driven_agent_factory", + "ConfDrivenAgentFactory", + ), + ("swarmauri.agent_factories.concrete.JsonAgentFactory", "JsonAgentFactory"), + ( + "swarmauri.agent_factories.concrete.ReflectionAgentFactory", + "ReflectionAgentFactory", + ), +] + +# Lazy loading of agent factories storing them in variables +for module_name, class_name in agent_factory_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded agent factories to __all__ +__all__ = [class_name for _, class_name in agent_factory_files] diff --git a/pkgs/swarmauri/swarmauri/agents/concrete/SimpleConversationAgent.py b/pkgs/swarmauri/swarmauri/agents/concrete/SimpleConversationAgent.py index 9653cf9c6..3dcd0d4bf 100644 --- a/pkgs/swarmauri/swarmauri/agents/concrete/SimpleConversationAgent.py +++ b/pkgs/swarmauri/swarmauri/agents/concrete/SimpleConversationAgent.py @@ -16,7 +16,7 @@ class SimpleConversationAgent(AgentConversationMixin, AgentBase): def exec( self, - input_str: Optional[Union[str, List[contentItem]]] = "", + input_data: Optional[Union[str, List[contentItem]]] = "", llm_kwargs: Optional[Dict] = {}, ) -> Any: diff --git a/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py b/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py index f474dae0f..0103905cc 100644 --- a/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/agents/concrete/__init__.py @@ -1,27 +1,10 @@ -import importlib - -# Define a lazy loader function with a warning message if the module or class is not found -def _lazy_import(module_name, class_name): - try: - # Import the module - module = importlib.import_module(module_name) - # Dynamically get the class from the module 
-        return getattr(module, class_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-    except AttributeError:
-        # If class is not found, print a warning message
-        print(f"Warning: The class '{class_name}' was not found in module '{module_name}'.")
-        return None
+from swarmauri.utils._lazy_import import _lazy_import

 # List of agent names (file names without the ".py" extension) and corresponding class names
 agent_files = [
-    ("swarmauri.agents.concrete.SimpleConversationAgent", "SimpleConversationAgent"),
     ("swarmauri.agents.concrete.QAAgent", "QAAgent"),
     ("swarmauri.agents.concrete.RagAgent", "RagAgent"),
+    ("swarmauri.agents.concrete.SimpleConversationAgent", "SimpleConversationAgent"),
     ("swarmauri.agents.concrete.ToolAgent", "ToolAgent"),
 ]
diff --git a/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py b/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py
index efdd73eff..d6e508040 100644
--- a/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/chains/concrete/__init__.py
@@ -1,11 +1,15 @@
-from swarmauri.chains.concrete.CallableChain import CallableChain
-from swarmauri.chains.concrete.ChainStep import ChainStep
-from swarmauri.chains.concrete.PromptContextChain import PromptContextChain
-from swarmauri.chains.concrete.ContextChain import ContextChain
+from swarmauri.utils._lazy_import import _lazy_import

-__all__ = [
-    "CallableChain",
-    "ChainStep",
-    "PromptContextChain",
-    "ContextChain",
+chains_files = [
+    ("swarmauri.chains.concrete.CallableChain", "CallableChain"),
+    ("swarmauri.chains.concrete.ChainStep", "ChainStep"),
+    ("swarmauri.chains.concrete.PromptContextChain", "PromptContextChain"),
+    ("swarmauri.chains.concrete.ContextChain", "ContextChain"),
 ]
+
+# Lazy loading of chain classes, storing them in variables
+for module_name, class_name in chains_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded chain classes to __all__
+__all__ = [class_name for _, class_name in chains_files]
diff --git a/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py b/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py
index f894163ad..fafead5cf 100644
--- a/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/chunkers/concrete/__init__.py
@@ -1,13 +1,17 @@
-from swarmauri.chunkers.concrete.DelimiterBasedChunker import DelimiterBasedChunker
-from swarmauri.chunkers.concrete.FixedLengthChunker import FixedLengthChunker
-from swarmauri.chunkers.concrete.MdSnippetChunker import MdSnippetChunker
-from swarmauri.chunkers.concrete.SentenceChunker import SentenceChunker
-from swarmauri.chunkers.concrete.SlidingWindowChunker import SlidingWindowChunker
+from swarmauri.utils._lazy_import import _lazy_import

-__all__ = [
-    "DelimiterBasedChunker",
-    "FixedLengthChunker",
-    "MdSnippetChunker",
-    "SentenceChunker",
-    "SlidingWindowChunker",
+# List of chunker names (file names without the ".py" extension) and corresponding class names
+chunkers_files = [
+    ("swarmauri.chunkers.concrete.DelimiterBasedChunker", "DelimiterBasedChunker"),
+    ("swarmauri.chunkers.concrete.FixedLengthChunker", "FixedLengthChunker"),
+    ("swarmauri.chunkers.concrete.MdSnippetChunker", "MdSnippetChunker"),
+    ("swarmauri.chunkers.concrete.SentenceChunker", "SentenceChunker"),
+
("swarmauri.chunkers.concrete.SlidingWindowChunker", "SlidingWindowChunker"), ] + +# Lazy loading of chunker classes, storing them in variables +for module_name, class_name in chunkers_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded chunker classes to __all__ +__all__ = [class_name for _, class_name in chunkers_files] diff --git a/pkgs/swarmauri/swarmauri/control_panels/__init__.py b/pkgs/swarmauri/swarmauri/control_panels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/control_panels/base/ControlPanelBase.py b/pkgs/swarmauri/swarmauri/control_panels/base/ControlPanelBase.py new file mode 100644 index 000000000..2e36f48cc --- /dev/null +++ b/pkgs/swarmauri/swarmauri/control_panels/base/ControlPanelBase.py @@ -0,0 +1,99 @@ +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes +from swarmauri_core.control_panels.IControlPanel import IControlPlane +from typing import Any, List, Literal +from pydantic import Field, ConfigDict +from swarmauri.service_registries.base.ServiceRegistryBase import ServiceRegistryBase +from swarmauri.factories.base.FactoryBase import FactoryBase +from swarmauri.task_mgt_strategies.base.TaskMgtStrategyBase import TaskMgtStrategyBase +from swarmauri.transports.base.TransportBase import TransportBase +from swarmauri_core.typing import SubclassUnion +import logging + + +class ControlPanelBase(IControlPlane, ComponentBase): + """ + Implementation of the ControlPlane abstract class. + This class orchestrates agents, manages tasks, and ensures task distribution + and transport between agents and services. + """ + + resource: ResourceTypes = Field(default=ResourceTypes.CONTROL_PANEL.value) + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + type: Literal["ControlPanelBase"] = "ControlPanelBase" + + agent_factory: SubclassUnion[FactoryBase] + service_registry: SubclassUnion[ServiceRegistryBase] + task_mgt_strategy: SubclassUnion[TaskMgtStrategyBase] + transport: SubclassUnion[TransportBase] + + # Agent management methods + def create_agent(self, name: str, role: str) -> Any: + """ + Create an agent with the given name and role, and register it in the service registry. + """ + agent = self.agent_factory.create_agent(name, role) + self.service_registry.register_service(name, {"role": role, "status": "active"}) + logging.info(f"Agent '{name}' with role '{role}' created and registered.") + return agent + + def remove_agent(self, name: str) -> None: + """ + Remove the agent with the specified name and unregister it from the service registry. + """ + agent = self.agent_factory.get_agent_by_name(name) + if not agent: + raise ValueError(f"Agent '{name}' not found.") + self.service_registry.unregister_service(name) + self.agent_factory.delete_agent(name) + logging.info(f"Agent '{name}' removed and unregistered.") + + def list_active_agents(self) -> List[str]: + """ + List all active agent names. + """ + agents = self.agent_factory.get_agents() + active_agents = [agent.name for agent in agents if agent] + logging.info(f"Active agents listed: {active_agents}") + return active_agents + + # Task management methods + def submit_tasks(self, tasks: List[Any]) -> None: + """ + Submit one or more tasks to the task management strategy for processing. + """ + for task in tasks: + self.task_mgt_strategy.add_task(task) + logging.info( + f"Task '{task.get('task_id', 'unknown')}' submitted to the strategy." 
+ ) + + def process_tasks(self) -> None: + """ + Process and assign tasks from the queue, then transport them to their assigned services. + """ + try: + self.task_mgt_strategy.process_tasks( + self.service_registry.get_services, self.transport + ) + logging.info("Tasks processed and transported successfully.") + except Exception as e: + logging.error(f"Error while processing tasks: {e}") + raise ValueError(f"Error processing tasks: {e}") + + def distribute_tasks(self, task: Any) -> None: + """ + Distribute tasks using the task strategy (manual or on-demand assignment). + """ + self.task_mgt_strategy.assign_task(task, self.service_registry.get_services) + logging.info( + f"Task '{task.get('task_id', 'unknown')}' distributed to a service." + ) + + # Orchestration method + def orchestrate_agents(self, tasks: List[Any]) -> None: + """ + Orchestrate agents for task distribution and transportation. + """ + self.submit_tasks(tasks) # Add task to the strategy + self.process_tasks() # Process and transport the task + logging.info("Agents orchestrated successfully.") diff --git a/pkgs/swarmauri/swarmauri/control_panels/base/__init__.py b/pkgs/swarmauri/swarmauri/control_panels/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/control_panels/concrete/ControlPanel.py b/pkgs/swarmauri/swarmauri/control_panels/concrete/ControlPanel.py new file mode 100644 index 000000000..14a7b2205 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/control_panels/concrete/ControlPanel.py @@ -0,0 +1,10 @@ +from typing import Literal +from swarmauri.control_panels.base.ControlPanelBase import ControlPanelBase + + +class ControlPanel(ControlPanelBase): + """ + Concrete implementation of the ControlPanelBase. + """ + + type: Literal["ControlPanel"] = "ControlPanel" diff --git a/pkgs/swarmauri/swarmauri/control_panels/concrete/__init__.py b/pkgs/swarmauri/swarmauri/control_panels/concrete/__init__.py new file mode 100644 index 000000000..767127c45 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/control_panels/concrete/__init__.py @@ -0,0 +1,13 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of control_panels names (file names without the ".py" extension) and corresponding class names +control_panels_files = [ + ("swarmauri.control_panels.concrete.ControlPanel", "ControlPanel"), +] + +# Lazy loading of task_mgt_strategies classes, storing them in variables +for module_name, class_name in control_panels_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded state classes to __all__ +__all__ = [class_name for _, class_name in control_panels_files] diff --git a/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py b/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py index e51d24fe0..46179d6fb 100644 --- a/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/conversations/concrete/__init__.py @@ -1,15 +1,22 @@ -from swarmauri.conversations.concrete.Conversation import Conversation -from swarmauri.conversations.concrete.MaxSystemContextConversation import ( - MaxSystemContextConversation, -) -from swarmauri.conversations.concrete.MaxSizeConversation import MaxSizeConversation -from swarmauri.conversations.concrete.SessionCacheConversation import ( - SessionCacheConversation, -) +from swarmauri.utils._lazy_import import _lazy_import -__all__ = [ - "Conversation", - "MaxSystemContextConversation", - "MaxSizeConversation", - "SessionCacheConversation", +# List of 
conversations names (file names without the ".py" extension) and corresponding class names +conversations_files = [ + ("swarmauri.conversations.concrete.Conversation", "Conversation"), + ( + "swarmauri.conversations.concrete.MaxSystemContextConversation", + "MaxSystemContextConversation", + ), + ("swarmauri.conversations.concrete.MaxSizeConversation", "MaxSizeConversation"), + ( + "swarmauri.conversations.concrete.SessionCacheConversation", + "SessionCacheConversation", + ), ] + +# Lazy loading of conversations classes, storing them in variables +for module_name, class_name in conversations_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded conversations classes to __all__ +__all__ = [class_name for _, class_name in conversations_files] diff --git a/pkgs/swarmauri/swarmauri/dataconnectors/base/DataConnectorBase.py b/pkgs/swarmauri/swarmauri/dataconnectors/base/DataConnectorBase.py new file mode 100644 index 000000000..0cb70a202 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/dataconnectors/base/DataConnectorBase.py @@ -0,0 +1,79 @@ +from swarmauri_core.dataconnectors.IDataConnector import IDataConnector + + +class DataConnectorBase(IDataConnector): + """ + Base implementation of IDataConnector that raises NotImplementedError + for all abstract methods, ensuring explicit implementation in child classes. + """ + + def authenticate(self, **kwargs): + """ + Raises NotImplementedError to enforce implementation in child classes. + + :param kwargs: Authentication parameters + :raises NotImplementedError: Always raised to require specific implementation + """ + raise NotImplementedError( + "Authenticate method must be implemented by child classes." + ) + + def fetch_data(self, query: str, **kwargs): + """ + Raises NotImplementedError to enforce implementation in child classes. + + :param query: Query string or parameters + :param kwargs: Additional parameters + :raises NotImplementedError: Always raised to require specific implementation + """ + raise NotImplementedError( + "Fetch data method must be implemented by child classes." + ) + + def insert_data(self, data, **kwargs): + """ + Raises NotImplementedError to enforce implementation in child classes. + + :param data: Data to be inserted + :param kwargs: Additional parameters + :raises NotImplementedError: Always raised to require specific implementation + """ + raise NotImplementedError( + "Insert data method must be implemented by child classes." + ) + + def update_data(self, identifier, data, **kwargs): + """ + Raises NotImplementedError to enforce implementation in child classes. + + :param identifier: Unique identifier of the data to update + :param data: Updated data + :param kwargs: Additional parameters + :raises NotImplementedError: Always raised to require specific implementation + """ + raise NotImplementedError( + "Update data method must be implemented by child classes." + ) + + def delete_data(self, identifier, **kwargs): + """ + Raises NotImplementedError to enforce implementation in child classes. + + :param identifier: Unique identifier of the data to delete + :param kwargs: Additional parameters + :raises NotImplementedError: Always raised to require specific implementation + """ + raise NotImplementedError( + "Delete data method must be implemented by child classes." + ) + + def test_connection(self, **kwargs): + """ + Raises NotImplementedError to enforce implementation in child classes. 
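The many `__init__.py` conversions in this PR import `_lazy_import` from `swarmauri.utils._lazy_import`, a helper this diff never shows. Judging from the inline version deleted from `agents/concrete/__init__.py` earlier, it presumably looks like the following; note that despite the name, each class is resolved eagerly when the package `__init__` runs, and the "laziness" is only that a missing optional dependency degrades to a warning and `None` instead of an `ImportError`:

```python
# Presumed content of swarmauri/utils/_lazy_import.py, reconstructed from
# the inline helper deleted in this PR; the actual file may differ.
import importlib


def _lazy_import(module_name, class_name):
    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except ImportError:
        print(
            f"Warning: The module '{module_name}' is not available. "
            f"Please install the necessary dependencies to enable this functionality."
        )
        return None
    except AttributeError:
        print(
            f"Warning: The class '{class_name}' was not found in module '{module_name}'."
        )
        return None
```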
+ + :param kwargs: Connection parameters + :raises NotImplementedError: Always raised to require specific implementation + """ + raise NotImplementedError( + "Test connection method must be implemented by child classes." + ) diff --git a/pkgs/swarmauri/swarmauri/dataconnectors/base/__init__.py b/pkgs/swarmauri/swarmauri/dataconnectors/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/dataconnectors/concrete/GoogleDriveDataConnector.py b/pkgs/swarmauri/swarmauri/dataconnectors/concrete/GoogleDriveDataConnector.py new file mode 100644 index 000000000..57fa66a2a --- /dev/null +++ b/pkgs/swarmauri/swarmauri/dataconnectors/concrete/GoogleDriveDataConnector.py @@ -0,0 +1,341 @@ +import logging +from urllib.parse import urlencode +import httpx +import base64 +import json +from typing import List +from swarmauri.documents.base.DocumentBase import DocumentBase +from swarmauri.dataconnectors.base.DataConnectorBase import DataConnectorBase + + +class GoogleDriveDataConnector(DataConnectorBase): + """ + Data connector for interacting with Google Drive files and converting them to Swarmauri documents. + + Supports authentication, data fetching, and basic CRUD operations for Google Drive resources. + """ + + def __init__(self, credentials_path: str = None): + """ + Initialize the Google Drive Data Connector. + + :param credentials_path: Path to the Google OAuth2 credentials JSON file + """ + with open(credentials_path, "r") as cred_file: + credentials = json.load(cred_file) + + self.client_id = credentials.get("client_id") + self.client_secret = credentials.get("client_secret") + self.redirect_uri = credentials.get("redirect_uri") + + # Tokens will be stored here + self.access_token = None + self.refresh_token = None + + self.authorization_code = None + + self.client = httpx.Client() + + def generate_authorization_url(self) -> str: + """Generate the authorization URL for user consent""" + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "response_type": "code", + "scope": "https://www.googleapis.com/auth/drive", + "access_type": "offline", # This ensures we get a refresh token + } + return f"https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}" + + def _exchange_code_for_tokens(self): + """Exchange authorization code for access and refresh tokens""" + if not self.authorization_code: + raise ValueError("No authorization code available") + + token_url = "https://oauth2.googleapis.com/token" + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": self.authorization_code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, + } + + response = self.client.post(token_url, data=payload) + tokens = response.json() + + logging.info(f"Token response: {tokens}") + if "access_token" not in tokens: + raise ValueError("Failed to obtain access token") + self.access_token = tokens["access_token"] + self.refresh_token = tokens.get("refresh_token") + + def refresh_access_token(self): + """Refresh the access token using the refresh token""" + if not self.refresh_token: + raise ValueError("No refresh token available") + + token_url = "https://oauth2.googleapis.com/token" + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + } + + response = self.client.post(token_url, data=payload) + tokens = response.json() + self.access_token = tokens["access_token"] + + def 
authenticate(self): + """ + Authenticate with Google Drive using OAuth2. + + This method generates an authorization URL, prompts the user to visit the URL + and enter the authorization code, and then exchanges the code for tokens. + """ + try: + # Generate authorization URL + auth_url = self.generate_authorization_url() + print("Please visit the following URL to authenticate:") + print(auth_url) + + # Prompt for authorization code + while True: + authorization_code = input("Enter the authorization code: ").strip() + + if not authorization_code: + print("Authorization code cannot be empty. Please try again.") + continue + + self.authorization_code = authorization_code + + try: + self._exchange_code_for_tokens() + logging.info("Successfully authenticated and obtained tokens") + return + except ValueError as e: + print(f"Error exchanging authorization code: {e}") + print("Please try again.") + self.authorization_code = None + + except Exception as e: + logging.error(f"Authentication failed: {e}") + raise ValueError(f"Authentication failed: {e}") + + def fetch_data(self, query: str = None, **kwargs) -> List[DocumentBase]: + """ + Fetch documents from Google Drive based on a query. + + :param query: Search query for files (optional) + :param kwargs: Additional parameters like mime_type, max_results + :return: List of Swarmauri Documents + """ + if not self.access_token: + raise ValueError("Not authenticated. Call authenticate() first.") + + try: + # Prepare request parameters + query_str = query or "" + mime_type = kwargs.get("mime_type", "application/vnd.google-apps.document") + max_results = kwargs.get("max_results", 100) + + # Construct request headers and parameters + headers = { + "Authorization": f"Bearer {self.access_token}", + "Accept": "application/json", + } + + params = { + "q": f"mimeType='{mime_type}' and name contains '{query_str}'", + "pageSize": max_results, + "fields": "files(id,name,mimeType)", + } + + # Make request to Google Drive API + response = self.client.get( + "https://www.googleapis.com/drive/v3/files", + headers=headers, + params=params, + ) + response.raise_for_status() + + files = response.json().get("files", []) + + # Convert Google Drive files to Swarmauri Documents + documents = [] + for file in files: + content = self._get_file_content(file["id"]) + document = DocumentBase( + content=content, + metadata={ + "id": file["id"], + "name": file["name"], + "mime_type": file["mimeType"], + }, + ) + documents.append(document) + + return documents + + except httpx.HTTPError as error: + raise ValueError(f"Error fetching Google Drive files: {error}") + + def _get_file_content(self, file_id: str) -> str: + """ + Retrieve text content from a Google Drive file. + + :param file_id: ID of the Google Drive file + :return: Text content of the file + """ + try: + # Prepare export request + headers = {"Authorization": f"Bearer {self.access_token}"} + + # Export file as plain text + export_url = f"https://www.googleapis.com/drive/v3/files/{file_id}/export" + params = {"mimeType": "text/plain"} + + response = self.client.get(export_url, headers=headers, params=params) + response.raise_for_status() + + return response.text + + except httpx.HTTPError as error: + print(f"An error occurred retrieving file content: {error}") + return "" + + def insert_data(self, data, **kwargs): + """ + Insert a new file into Google Drive. 
+ + :param data: Content of the file to be inserted + :param kwargs: Additional metadata like filename, mime_type + :return: ID of the inserted file + """ + if not self.access_token: + raise ValueError("Not authenticated. Call authenticate() first.") + + try: + headers = { + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + } + + # Prepare file metadata + file_metadata = { + "name": kwargs.get("filename", "Untitled Document"), + "mimeType": kwargs.get( + "mime_type", "application/vnd.google-apps.document" + ), + } + + # Prepare file content (base64 encoded) + media_content = base64.b64encode(data.encode("utf-8")).decode("utf-8") + + # Construct payload + payload = { + "metadata": file_metadata, + "media": {"mimeType": "text/plain", "body": media_content}, + } + + # Make request to create file + response = self.client.post( + "https://www.googleapis.com/upload/drive/v3/files", + headers=headers, + json=payload, + ) + response.raise_for_status() + + return response.json().get("id") + + except httpx.HTTPError as error: + raise ValueError(f"Error inserting file: {error}") + + def update_data(self, identifier, data, **kwargs): + """ + Update an existing Google Drive file. + + :param identifier: File ID to update + :param data: New content for the file + :param kwargs: Additional update parameters + """ + if not self.access_token: + raise ValueError("Not authenticated. Call authenticate() first.") + + try: + headers = { + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + } + + # Prepare file content (base64 encoded) + media_content = base64.b64encode(data.encode("utf-8")).decode("utf-8") + + # Construct payload + payload = {"media": {"mimeType": "text/plain", "body": media_content}} + + # Make request to update file + response = self.client.patch( + f"https://www.googleapis.com/upload/drive/v3/files/{identifier}", + headers=headers, + json=payload, + ) + response.raise_for_status() + + except httpx.HTTPError as error: + raise ValueError(f"Error updating file: {error}") + + def delete_data(self, identifier, **kwargs): + """ + Delete a file from Google Drive. + + :param identifier: File ID to delete + """ + if not self.access_token: + raise ValueError("Not authenticated. Call authenticate() first.") + + try: + headers = {"Authorization": f"Bearer {self.access_token}"} + + response = self.client.delete( + f"https://www.googleapis.com/drive/v3/files/{identifier}", + headers=headers, + ) + response.raise_for_status() + + except httpx.HTTPError as error: + raise ValueError(f"Error deleting file: {error}") + + def test_connection(self, **kwargs): + """ + Test the connection to Google Drive by listing files. 
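+
+        Authenticates first when no access token is present, then lists up to
+        five files as a lightweight reachability check.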
+ + :return: Boolean indicating connection success + """ + try: + if not self.access_token: + self.authenticate(**kwargs) + + # Prepare request headers + headers = { + "Authorization": f"Bearer {self.access_token}", + "Accept": "application/json", + } + + # List first 5 files to test connection + params = {"pageSize": 5, "fields": "files(id,name)"} + + response = self.client.get( + "https://www.googleapis.com/drive/v3/files", + headers=headers, + params=params, + ) + response.raise_for_status() + + files = response.json().get("files", []) + return len(files) > 0 + + except Exception as e: + print(f"Connection test failed: {e}") + return False diff --git a/pkgs/swarmauri/swarmauri/dataconnectors/concrete/__init__.py b/pkgs/swarmauri/swarmauri/dataconnectors/concrete/__init__.py new file mode 100644 index 000000000..612d7f2e4 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/dataconnectors/concrete/__init__.py @@ -0,0 +1,16 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of data connector files names (file names without the ".py" extension) and corresponding class names +data_connector_files = [ + ( + "swarmauri.dataconnectors.concrete.GoogleDriveDataConnector", + "GoogleDriveDataConnector", + ), +] + +# Lazy loading of data connector classes, storing them in variables +for module_name, class_name in data_connector_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded data connector classes to __all__ +__all__ = [class_name for _, class_name in data_connector_files] diff --git a/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py b/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py index 9e163ca4d..033a0dd13 100644 --- a/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/distances/concrete/__init__.py @@ -1,34 +1,27 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_description or module_name}' is not available. 
" - f"Please install the necessary dependencies to enable this functionality.") - return None - -# List of distance names (file names without the ".py" extension) -distance_files = [ - "CanberraDistance", - "ChebyshevDistance", - "ChiSquaredDistance", - "CosineDistance", - "EuclideanDistance", - "HaversineDistance", - "JaccardIndexDistance", - "LevenshteinDistance", - "ManhattanDistance", - "MinkowskiDistance", - "SorensenDiceDistance", - "SquaredEuclideanDistance", +# List of distances names (file names without the ".py" extension) and corresponding class names +distances_files = [ + ("swarmauri.distances.concrete.CanberraDistance", "CanberraDistance"), + ("swarmauri.distances.concrete.ChebyshevDistance", "ChebyshevDistance"), + ("swarmauri.distances.concrete.ChiSquaredDistance", "ChiSquaredDistance"), + ("swarmauri.distances.concrete.CosineDistance", "CosineDistance"), + ("swarmauri.distances.concrete.EuclideanDistance", "EuclideanDistance"), + ("swarmauri.distances.concrete.HaversineDistance", "HaversineDistance"), + ("swarmauri.distances.concrete.JaccardIndexDistance", "JaccardIndexDistance"), + ("swarmauri.distances.concrete.LevenshteinDistance", "LevenshteinDistance"), + ("swarmauri.distances.concrete.ManhattanDistance", "ManhattanDistance"), + ("swarmauri.distances.concrete.MinkowskiDistance", "MinkowskiDistance"), + ("swarmauri.distances.concrete.SorensenDiceDistance", "SorensenDiceDistance"), + ( + "swarmauri.distances.concrete.SquaredEuclideanDistance", + "SquaredEuclideanDistance", + ), ] -# Lazy loading of distance modules, storing them in variables -for distance in distance_files: - globals()[distance] = _lazy_import(f"swarmauri.distances.concrete.{distance}", distance) +# Lazy loading of distances classes, storing them in variables +for module_name, class_name in distances_files: + globals()[class_name] = _lazy_import(module_name, class_name) -# Adding the lazy-loaded distance modules to __all__ -__all__ = distance_files +# Adding the lazy-loaded distances classes to __all__ +__all__ = [class_name for _, class_name in distances_files] diff --git a/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py b/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py index f0725fde0..c4b50e1a5 100644 --- a/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/documents/concrete/__init__.py @@ -1 +1,11 @@ -from swarmauri.documents.concrete import * +from swarmauri.utils._lazy_import import _lazy_import + +# List of documents names (file names without the ".py" extension) and corresponding class names +documents_files = [("swarmauri.documents.concrete.Document", "Document")] + +# Lazy loading of documents classes, storing them in variables +for module_name, class_name in documents_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded documents classes to __all__ +__all__ = [class_name for _, class_name in documents_files] diff --git a/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py b/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py index a1f0f231c..c6d12f871 100644 --- a/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/embeddings/concrete/__init__.py @@ -1,31 +1,19 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # 
If module is not available, print a warning message - print(f"Warning: The module '{module_description or module_name}' is not available. " - f"Please install the necessary dependencies to enable this functionality.") - return None +# List of embeddings names (file names without the ".py" extension) and corresponding class names +embeddings_files = [ + ("swarmauri.embeddings.concrete.CohereEmbedding", "CohereEmbedding"), + ("swarmauri.embeddings.concrete.GeminiEmbedding", "GeminiEmbedding"), + ("swarmauri.embeddings.concrete.MistralEmbedding", "MistralEmbedding"), + ("swarmauri.embeddings.concrete.NmfEmbedding", "NmfEmbedding"), + ("swarmauri.embeddings.concrete.OpenAIEmbedding", "OpenAIEmbedding"), + ("swarmauri.embeddings.concrete.TfidfEmbedding", "TfidfEmbedding"), + ("swarmauri.embeddings.concrete.VoyageEmbedding", "VoyageEmbedding"), +] -# Lazy loading of embeddings with descriptive names -Doc2VecEmbedding = _lazy_import("swarmauri.embeddings.concrete.Doc2VecEmbedding", "Doc2VecEmbedding") -GeminiEmbedding = _lazy_import("swarmauri.embeddings.concrete.GeminiEmbedding", "GeminiEmbedding") -MistralEmbedding = _lazy_import("swarmauri.embeddings.concrete.MistralEmbedding", "MistralEmbedding") -MlmEmbedding = _lazy_import("swarmauri.embeddings.concrete.MlmEmbedding", "MlmEmbedding") -NmfEmbedding = _lazy_import("swarmauri.embeddings.concrete.NmfEmbedding", "NmfEmbedding") -OpenAIEmbedding = _lazy_import("swarmauri.embeddings.concrete.OpenAIEmbedding", "OpenAIEmbedding") -TfidfEmbedding = _lazy_import("swarmauri.embeddings.concrete.TfidfEmbedding", "TfidfEmbedding") +# Lazy loading of embeddings classes, storing them in variables +for module_name, class_name in embeddings_files: + globals()[class_name] = _lazy_import(module_name, class_name) -# Adding lazy-loaded modules to __all__ -__all__ = [ - "Doc2VecEmbedding", - "GeminiEmbedding", - "MistralEmbedding", - "MlmEmbedding", - "NmfEmbedding", - "OpenAIEmbedding", - "TfidfEmbedding", -] +# Adding the lazy-loaded embeddings classes to __all__ +__all__ = [class_name for _, class_name in embeddings_files] diff --git a/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py b/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py index 43b631bc1..2baf7a56d 100644 --- a/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/exceptions/concrete/__init__.py @@ -1,3 +1,13 @@ -from swarmauri.exceptions.concrete.IndexErrorWithContext import IndexErrorWithContext +from swarmauri.utils._lazy_import import _lazy_import -__all__ = ["IndexErrorWithContext"] +# List of exceptions names (file names without the ".py" extension) and corresponding class names +exceptions_files = [ + ("swarmauri.exceptions.concrete.IndexErrorWithContext", "IndexErrorWithContext"), +] + +# Lazy loading of exceptions classes, storing them in variables +for module_name, class_name in exceptions_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded exceptions classes to __all__ +__all__ = [class_name for _, class_name in exceptions_files] diff --git a/pkgs/swarmauri/swarmauri/factories/__init__.py b/pkgs/swarmauri/swarmauri/factories/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/factories/base/FactoryBase.py b/pkgs/swarmauri/swarmauri/factories/base/FactoryBase.py new file mode 100644 index 000000000..ef1b9324e --- /dev/null +++ b/pkgs/swarmauri/swarmauri/factories/base/FactoryBase.py @@ -0,0 +1,31 @@ +from typing import Any, Callable, Literal, 
Optional
+from swarmauri_core.ComponentBase import ComponentBase
+from swarmauri_core.ComponentBase import ResourceTypes
+from swarmauri_core.factories.IFactory import IFactory
+from pydantic import ConfigDict, Field
+
+
+class FactoryBase(IFactory, ComponentBase):
+    """
+    Base factory class for registering and creating instances.
+    """
+
+    resource: Optional[str] = Field(default=ResourceTypes.FACTORY.value, frozen=True)
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    type: Literal["FactoryBase"] = "FactoryBase"
+
+    def register(self, type: str, resource_class: Callable) -> None:
+        """
+        Register a resource class under a specific type.
+        """
+        raise NotImplementedError(
+            "register method must be implemented in derived classes."
+        )
+
+    def create(self, type: str, *args: Any, **kwargs: Any) -> Any:
+        """
+        Create an instance of the class associated with the given type.
+        """
+        raise NotImplementedError(
+            "create method must be implemented in derived classes."
+        )
diff --git a/pkgs/swarmauri/swarmauri/factories/base/__init__.py b/pkgs/swarmauri/swarmauri/factories/base/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pkgs/swarmauri/swarmauri/factories/concrete/AgentFactory.py b/pkgs/swarmauri/swarmauri/factories/concrete/AgentFactory.py
new file mode 100644
index 000000000..e641bf8f1
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/factories/concrete/AgentFactory.py
@@ -0,0 +1,35 @@
+from typing import Any, Callable, Dict, Literal
+from swarmauri.factories.base.FactoryBase import FactoryBase
+
+
+class AgentFactory(FactoryBase):
+    """
+    Class-specific factory for managing resources and types.
+    """
+
+    type: Literal["AgentFactory"] = "AgentFactory"
+    _registry: Dict[str, Callable] = {}
+
+    def register(self, type: str, resource_class: Callable) -> None:
+        """
+        Register a resource class with a specific type.
+        """
+        if type in self._registry:
+            raise ValueError(f"Type '{type}' is already registered.")
+        self._registry[type] = resource_class
+
+    def create(self, type: str, *args: Any, **kwargs: Any) -> Any:
+        """
+        Create an instance of the class associated with the given type name.
+        """
+        if type not in self._registry:
+            raise ValueError(f"Type '{type}' is not registered.")
+
+        cls = self._registry[type]
+        return cls(*args, **kwargs)
+
+    def get_agents(self):
+        """
+        Return a list of registered agent types.
+        """
+        return list(self._registry.keys())
diff --git a/pkgs/swarmauri/swarmauri/factories/concrete/Factory.py b/pkgs/swarmauri/swarmauri/factories/concrete/Factory.py
new file mode 100644
index 000000000..9242ffd5d
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/factories/concrete/Factory.py
@@ -0,0 +1,42 @@
+from typing import Any, Callable, Dict, Literal
+from swarmauri.factories.base.FactoryBase import FactoryBase
+from swarmauri.utils._get_subclasses import get_classes_from_module
+
+
+class Factory(FactoryBase):
+    """
+    Non-recursive factory extending FactoryBase.
+    """
+
+    type: Literal["Factory"] = "Factory"
+    _resource_registry: Dict[str, Dict[str, Callable]] = {}
+
+    def register(self, resource: str, type: str, resource_class: Callable) -> None:
+        """
+        Register a resource class under a specific resource and type.
+        """
+        # Populate the registry for this resource on first use so duplicates
+        # against module-discovered classes are also caught
+        if resource not in self._resource_registry:
+            self._resource_registry[resource] = get_classes_from_module(resource)
+
+        if type in self._resource_registry[resource]:
+            raise ValueError(
+                f"Type '{type}' is already registered under resource '{resource}'."
+            )
+
+        self._resource_registry[resource][type] = resource_class
+
+    def create(self, resource: str, type: str, *args: Any, **kwargs: Any) -> Any:
+        """
+        Create an instance of the class associated with the given resource and type.
+        """
+        if resource not in self._resource_registry:
+            self._resource_registry[resource] = get_classes_from_module(resource)
+
+        if type not in self._resource_registry[resource]:
+            raise ValueError(
+                f"Type '{type}' is not registered under resource '{resource}'."
+            )
+
+        cls = self._resource_registry[resource][type]
+        return cls(*args, **kwargs)
diff --git a/pkgs/swarmauri/swarmauri/factories/concrete/__init__.py b/pkgs/swarmauri/swarmauri/factories/concrete/__init__.py
new file mode 100644
index 000000000..014b941f2
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/factories/concrete/__init__.py
@@ -0,0 +1,14 @@
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of factories names (file names without the ".py" extension) and corresponding class names
+factories_files = [
+    ("swarmauri.factories.concrete.Factory", "Factory"),
+    ("swarmauri.factories.concrete.AgentFactory", "AgentFactory"),
+]
+
+# Lazy loading of factories classes, storing them in variables
+for module_name, class_name in factories_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded factories classes to __all__
+__all__ = [class_name for _, class_name in factories_files]
diff --git a/pkgs/swarmauri/swarmauri/image_gens/__init__.py b/pkgs/swarmauri/swarmauri/image_gens/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pkgs/swarmauri/swarmauri/image_gens/base/ImageGenBase.py b/pkgs/swarmauri/swarmauri/image_gens/base/ImageGenBase.py
new file mode 100644
index 000000000..ab3dff1f2
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/image_gens/base/ImageGenBase.py
@@ -0,0 +1,51 @@
+from abc import abstractmethod
+from typing import Any, Optional, List, Literal
+from pydantic import ConfigDict, model_validator, Field
+from swarmauri_core.image_gens.IGenImage import IGenImage
+from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
+
+
+class ImageGenBase(IGenImage, ComponentBase):
+    allowed_models: List[str] = []
+    resource: Optional[str] = Field(default=ResourceTypes.IMAGE_GEN.value, frozen=True)
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    type: Literal["ImageGenBase"] = "ImageGenBase"
+
+    @model_validator(mode="after")
+    @classmethod
+    def _validate_name_in_allowed_models(cls, values):
+        name = values.name
+        allowed_models = values.allowed_models
+        if name and name not in allowed_models:
+            raise ValueError(
+                f"Model name {name} is not allowed. Choose from {allowed_models}"
+            )
+        return values
+
+    @abstractmethod
+    def generate_image(self, *args, **kwargs) -> Any:
+        """
+        Generate images based on the input data provided to the model.
+        """
+        raise NotImplementedError("generate_image() not implemented in subclass yet.")
+
+    @abstractmethod
+    async def agenerate_image(self, *args, **kwargs) -> Any:
+        """
+        Asynchronously generate images based on the input data provided to the model.
+        """
+        raise NotImplementedError("agenerate_image() not implemented in subclass yet.")
+
+    @abstractmethod
+    def batch_generate(self, *args, **kwargs) -> Any:
+        """
+        Generate images for a batch of inputs provided to the model.
+        """
+        raise NotImplementedError("batch_generate() not implemented in subclass yet.")
+
+    @abstractmethod
+    async def abatch_generate(self, *args, **kwargs) -> Any:
+        """
+        Asynchronously generate images for a batch of inputs provided to the model.
+        """
+        raise NotImplementedError("abatch_generate() not implemented in subclass yet.")
diff --git a/pkgs/swarmauri/swarmauri/image_gens/base/__init__.py b/pkgs/swarmauri/swarmauri/image_gens/base/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/BlackForestImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/BlackForestImgGenModel.py
similarity index 96%
rename from pkgs/swarmauri/swarmauri/llms/concrete/BlackForestImgGenModel.py
rename to pkgs/swarmauri/swarmauri/image_gens/concrete/BlackForestImgGenModel.py
index 50d395394..d783c4204 100644
--- a/pkgs/swarmauri/swarmauri/llms/concrete/BlackForestImgGenModel.py
+++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/BlackForestImgGenModel.py
@@ -1,35 +1,35 @@
 import httpx
 import time
-from typing import List, Literal, Optional, Dict, ClassVar
+from typing import List, Literal, Optional, Dict
 from pydantic import PrivateAttr
 from swarmauri.utils.retry_decorator import retry_on_status_codes
-from swarmauri.llms.base.LLMBase import LLMBase
+from swarmauri.image_gens.base.ImageGenBase import ImageGenBase
 import asyncio
 import contextlib
 
 
-class BlackForestImgGenModel(LLMBase):
+class BlackForestImgGenModel(ImageGenBase):
     """
     A model for generating images using FluxPro's image generation models through
     the Black Forest API.
 
     Link to API key: https://api.bfl.ml/auth/profile
     """
 
     _BASE_URL: str = PrivateAttr("https://api.bfl.ml")
-    _client: httpx.Client = PrivateAttr()
+    _client: httpx.Client = PrivateAttr(default=None)
     _async_client: httpx.AsyncClient = PrivateAttr(default=None)
+    _headers: Dict[str, str] = PrivateAttr(default=None)
 
     api_key: str
     allowed_models: List[str] = ["flux-pro-1.1", "flux-pro", "flux-dev"]
-    asyncio: ClassVar = asyncio
 
     name: str = "flux-pro"  # Default model
     type: Literal["BlackForestImgGenModel"] = "BlackForestImgGenModel"
 
-    def __init__(self, **data):
+    def __init__(self, **kwargs):
         """
         Initializes the BlackForestImgGenModel instance with HTTP clients.
         """
-        super().__init__(**data)
+        super().__init__(**kwargs)
         self._headers = {
             "Content-Type": "application/json",
             "X-Key": self.api_key,
diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/DeepInfraImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/DeepInfraImgGenModel.py
similarity index 97%
rename from pkgs/swarmauri/swarmauri/llms/concrete/DeepInfraImgGenModel.py
rename to pkgs/swarmauri/swarmauri/image_gens/concrete/DeepInfraImgGenModel.py
index 56afc2105..5dcc35d98 100644
--- a/pkgs/swarmauri/swarmauri/llms/concrete/DeepInfraImgGenModel.py
+++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/DeepInfraImgGenModel.py
@@ -2,12 +2,12 @@
 from typing import List, Literal
 from pydantic import PrivateAttr
 from swarmauri.utils.retry_decorator import retry_on_status_codes
-from swarmauri.llms.base.LLMBase import LLMBase
+from swarmauri.image_gens.base.ImageGenBase import ImageGenBase
 import asyncio
 import contextlib
 
 
-class DeepInfraImgGenModel(LLMBase):
+class DeepInfraImgGenModel(ImageGenBase):
     """
     A model class for generating images from text prompts
     using DeepInfra's image generation API.
@@ -37,7 +37,7 @@ class DeepInfraImgGenModel(LLMBase): name: str = "stabilityai/stable-diffusion-2-1" # Default model type: Literal["DeepInfraImgGenModel"] = "DeepInfraImgGenModel" - def __init__(self, **data): + def __init__(self, **kwargs): """ Initializes the DeepInfraImgGenModel instance. @@ -47,7 +47,7 @@ def __init__(self, **data): Args: **data: Keyword arguments for model initialization. """ - super().__init__(**data) + super().__init__(**kwargs) self._headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/FalAIImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/FalAIImgGenModel.py similarity index 98% rename from pkgs/swarmauri/swarmauri/llms/concrete/FalAIImgGenModel.py rename to pkgs/swarmauri/swarmauri/image_gens/concrete/FalAIImgGenModel.py index 6943d1e59..b6eadb3b7 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/FalAIImgGenModel.py +++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/FalAIImgGenModel.py @@ -3,11 +3,11 @@ from typing import List, Literal, Optional, Dict from pydantic import Field, PrivateAttr from swarmauri.utils.retry_decorator import retry_on_status_codes -from swarmauri.llms.base.LLMBase import LLMBase +from swarmauri.image_gens.base.ImageGenBase import ImageGenBase import time -class FalAIImgGenModel(LLMBase): +class FalAIImgGenModel(ImageGenBase): """ A model class for generating images from text using FluxPro's image generation model, provided by FalAI. This class uses a queue-based API to handle image generation requests. @@ -34,7 +34,7 @@ class FalAIImgGenModel(LLMBase): max_retries: int = Field(default=60) # Maximum number of status check retries retry_delay: float = Field(default=1.0) # Delay between status checks in seconds - def __init__(self, **data): + def __init__(self, **kwargs): """ Initializes the model with the specified API key and model name. @@ -44,7 +44,7 @@ def __init__(self, **data): Raises: ValueError: If an invalid model name is provided. """ - super().__init__(**data) + super().__init__(**kwargs) self._headers = { "Content-Type": "application/json", "Authorization": f"Key {self.api_key}", diff --git a/pkgs/swarmauri/swarmauri/image_gens/concrete/HyperbolicImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/HyperbolicImgGenModel.py new file mode 100644 index 000000000..43f7dd60b --- /dev/null +++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/HyperbolicImgGenModel.py @@ -0,0 +1,210 @@ +import httpx +from typing import List, Literal +from pydantic import PrivateAttr +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.image_gens.base.ImageGenBase import ImageGenBase +import asyncio +import contextlib + + +class HyperbolicImgGenModel(ImageGenBase): + """ + A model class for generating images from text prompts using Hyperbolic's image generation API. + + Attributes: + api_key (str): The API key for authenticating with the Hyperbolic API. + allowed_models (List[str]): A list of available models for image generation. + asyncio (ClassVar): The asyncio module for handling asynchronous operations. + name (str): The name of the model to be used for image generation. + type (Literal["HyperbolicImgGenModel"]): The type identifier for the model class. + height (int): Height of the generated image. + width (int): Width of the generated image. + steps (int): Number of inference steps. + cfg_scale (float): Classifier-free guidance scale. + enable_refiner (bool): Whether to enable the refiner model. 
+ backend (str): Computational backend for the model. + + Link to Allowed Models: https://app.hyperbolic.xyz/models + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + _BASE_URL: str = PrivateAttr("https://api.hyperbolic.xyz/v1/image/generation") + _client: httpx.Client = PrivateAttr(default=None) + _async_client: httpx.AsyncClient = PrivateAttr(default=None) + + api_key: str + allowed_models: List[str] = [ + "SDXL1.0-base", + "SD1.5", + "SSD", + "SD2", + "SDXL-turbo", + ] + + name: str = "SDXL1.0-base" # Default model + type: Literal["HyperbolicImgGenModel"] = "HyperbolicImgGenModel" + + # Additional configuration parameters + height: int = 1024 + width: int = 1024 + steps: int = 30 + cfg_scale: float = 5.0 + enable_refiner: bool = False + backend: str = "auto" + + def __init__(self, **kwargs): + """ + Initializes the HyperbolicImgGenModel instance. + + This constructor sets up HTTP clients for both synchronous and asynchronous + operations and configures request headers with the provided API key. + + Args: + **data: Keyword arguments for model initialization. + """ + super().__init__(**kwargs) + self._headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + self._client = httpx.Client(headers=self._headers, timeout=30) + + async def _get_async_client(self) -> httpx.AsyncClient: + """ + Gets or creates an async client instance. + """ + if self._async_client is None or self._async_client.is_closed: + self._async_client = httpx.AsyncClient(headers=self._headers, timeout=30) + return self._async_client + + async def _close_async_client(self): + """ + Closes the async client if it exists and is open. + """ + if self._async_client is not None and not self._async_client.is_closed: + await self._async_client.aclose() + self._async_client = None + + def _create_request_payload(self, prompt: str) -> dict: + """ + Creates the payload for the image generation request. + """ + return { + "model_name": self.name, + "prompt": prompt, + "height": self.height, + "width": self.width, + "steps": self.steps, + "cfg_scale": self.cfg_scale, + "enable_refiner": self.enable_refiner, + "backend": self.backend, + } + + @retry_on_status_codes((429, 529), max_retries=1) + def _send_request(self, prompt: str) -> dict: + """ + Sends a synchronous request to the Hyperbolic API for image generation. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + dict: The response data from the API. + """ + payload = self._create_request_payload(prompt) + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + return response.json() + + @retry_on_status_codes((429, 529), max_retries=1) + async def _async_send_request(self, prompt: str) -> dict: + """ + Sends an asynchronous request to the Hyperbolic API for image generation. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + dict: The response data from the API. + """ + client = await self._get_async_client() + payload = self._create_request_payload(prompt) + response = await client.post(self._BASE_URL, json=payload) + response.raise_for_status() + return response.json() + + def generate_image_base64(self, prompt: str) -> str: + """ + Generates an image synchronously based on the provided prompt and returns it as a base64-encoded string. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + str: The base64-encoded representation of the generated image. 
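+
+        Example (illustrative sketch; assumes a valid Hyperbolic API key):
+            >>> img_model = HyperbolicImgGenModel(api_key="...")
+            >>> b64_png = img_model.generate_image_base64("a watercolor fox")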
+ """ + response_data = self._send_request(prompt) + return response_data["images"][0]["image"] + + async def agenerate_image_base64(self, prompt: str) -> str: + """ + Generates an image asynchronously based on the provided prompt and returns it as a base64-encoded string. + + Args: + prompt (str): The text prompt used for generating the image. + + Returns: + str: The base64-encoded representation of the generated image. + """ + try: + response_data = await self._async_send_request(prompt) + return response_data["images"][0]["image"] + finally: + await self._close_async_client() + + def batch_base64(self, prompts: List[str]) -> List[str]: + """ + Generates images for a batch of prompts synchronously and returns them as a list of base64-encoded strings. + + Args: + prompts (List[str]): A list of text prompts for image generation. + + Returns: + List[str]: A list of base64-encoded representations of the generated images. + """ + return [self.generate_image_base64(prompt) for prompt in prompts] + + async def abatch_base64( + self, prompts: List[str], max_concurrent: int = 5 + ) -> List[str]: + """ + Generates images for a batch of prompts asynchronously and returns them as a list of base64-encoded strings. + + Args: + prompts (List[str]): A list of text prompts for image generation. + max_concurrent (int): The maximum number of concurrent tasks. + + Returns: + List[str]: A list of base64-encoded representations of the generated images. + """ + try: + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_prompt(prompt): + async with semaphore: + response_data = await self._async_send_request(prompt) + return response_data["images"][0]["image"] + + tasks = [process_prompt(prompt) for prompt in prompts] + return await asyncio.gather(*tasks) + finally: + await self._close_async_client() + + def __del__(self): + """ + Cleanup method to ensure clients are closed. + """ + self._client.close() + if self._async_client is not None and not self._async_client.is_closed: + with contextlib.suppress(Exception): + asyncio.run(self._close_async_client()) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/OpenAIImgGenModel.py similarity index 97% rename from pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py rename to pkgs/swarmauri/swarmauri/image_gens/concrete/OpenAIImgGenModel.py index ad78fd7d8..8862799e5 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py +++ b/pkgs/swarmauri/swarmauri/image_gens/concrete/OpenAIImgGenModel.py @@ -2,11 +2,11 @@ import asyncio import httpx from typing import Dict, List, Literal, Optional +from swarmauri.image_gens.base.ImageGenBase import ImageGenBase from swarmauri.utils.retry_decorator import retry_on_status_codes -from swarmauri.llms.base.LLMBase import LLMBase -class OpenAIImgGenModel(LLMBase): +class OpenAIImgGenModel(ImageGenBase): """ OpenAIImgGenModel is a class for generating images using OpenAI's DALL-E models. @@ -26,14 +26,14 @@ class OpenAIImgGenModel(LLMBase): _BASE_URL: str = PrivateAttr(default="https://api.openai.com/v1/images/generations") _headers: Dict[str, str] = PrivateAttr(default=None) - def __init__(self, **data) -> None: + def __init__(self, **kwargs) -> None: """ Initialize the GroqAIAudio class with the provided data. Args: **data: Arbitrary keyword arguments containing initialization data. 
""" - super().__init__(**data) + super().__init__(**kwargs) self._headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", diff --git a/pkgs/swarmauri/swarmauri/image_gens/concrete/__init__.py b/pkgs/swarmauri/swarmauri/image_gens/concrete/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicAudioTTS.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicAudioTTS.py new file mode 100644 index 000000000..b6ff4de67 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicAudioTTS.py @@ -0,0 +1,146 @@ +import base64 +import io +import os +from typing import AsyncIterator, Iterator, List, Literal, Dict, Optional +import httpx +from pydantic import PrivateAttr, model_validator, Field +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.llms.base.LLMBase import LLMBase +import asyncio + + +class HyperbolicAudioTTS(LLMBase): + """ + A class to interact with Hyperbolic's Text-to-Speech API, allowing for synchronous + and asynchronous text-to-speech synthesis. + + Attributes: + api_key (str): The API key for accessing Hyperbolic's TTS service. + language (Optional[str]): Language of the text. + speaker (Optional[str]): Specific speaker variant. + speed (Optional[float]): Speech speed control. + + Provider Resource: https://api.hyperbolic.xyz/v1/audio/generation + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + api_key: str + + # Supported languages + allowed_languages: List[str] = ["EN", "ES", "FR", "ZH", "JP", "KR"] + + # Supported speakers per language + allowed_speakers: Dict[str, List[str]] = { + "EN": ["EN-US", "EN-BR", "EN-INDIA", "EN-AU"], + "ES": ["ES"], + "FR": ["FR"], + "ZH": ["ZH"], + "JP": ["JP"], + "KR": ["KR"], + } + + # Optional parameters with type hints and validation + language: Optional[str] = None + speaker: Optional[str] = None + speed: Optional[float] = Field(default=1.0, ge=0.1, le=5) + + type: Literal["HyperbolicAudioTTS"] = "HyperbolicAudioTTS" + _BASE_URL: str = PrivateAttr( + default="https://api.hyperbolic.xyz/v1/audio/generation" + ) + _headers: Dict[str, str] = PrivateAttr(default=None) + + def __init__(self, **data): + """ + Initialize the HyperbolicAudioTTS class with the provided data. + """ + super().__init__(**data) + self._headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + def _prepare_payload(self, text: str) -> Dict: + """ + Prepare the payload for the TTS request. + """ + payload = {"text": text} + + # Add optional parameters if they are set + if self.language: + payload["language"] = self.language + if self.speaker: + payload["speaker"] = self.speaker + if self.speed is not None: + payload["speed"] = self.speed + + return payload + + def predict(self, text: str, audio_path: str = "output.mp3") -> str: + """ + Synchronously converts text to speech. + """ + payload = self._prepare_payload(text) + + with httpx.Client(timeout=30) as client: + response = client.post(self._BASE_URL, headers=self._headers, json=payload) + response.raise_for_status() + + # Decode base64 audio + audio_data = base64.b64decode(response.json()["audio"]) + + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_data) + + return os.path.abspath(audio_path) + + async def apredict(self, text: str, audio_path: str = "output.mp3") -> str: + """ + Asynchronously converts text to speech. 
+ """ + payload = self._prepare_payload(text) + + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post( + self._BASE_URL, headers=self._headers, json=payload + ) + response.raise_for_status() + + # Decode base64 audio + audio_data = base64.b64decode(response.json()["audio"]) + + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_data) + + return os.path.abspath(audio_path) + + def batch( + self, + text_path_dict: Dict[str, str], + ) -> List[str]: + """ + Synchronously process multiple text-to-speech requests in batch mode. + """ + return [ + self.predict(text=text, audio_path=path) + for text, path in text_path_dict.items() + ] + + async def abatch( + self, + text_path_dict: Dict[str, str], + max_concurrent=5, + ) -> List[str]: + """ + Asynchronously process multiple text-to-speech requests in batch mode. + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(text, path) -> str: + async with semaphore: + return await self.apredict(text=text, audio_path=path) + + tasks = [ + process_conversation(text, path) for text, path in text_path_dict.items() + ] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicModel.py new file mode 100644 index 000000000..2b3677677 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicModel.py @@ -0,0 +1,427 @@ +import asyncio +import json +from pydantic import PrivateAttr +import httpx +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.utils.duration_manager import DurationManager +from swarmauri.conversations.concrete.Conversation import Conversation +from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator + +from swarmauri_core.typing import SubclassUnion +from swarmauri.messages.base.MessageBase import MessageBase +from swarmauri.messages.concrete.AgentMessage import AgentMessage +from swarmauri.llms.base.LLMBase import LLMBase + +from swarmauri.messages.concrete.AgentMessage import UsageData + + +class HyperbolicModel(LLMBase): + """ + HyperbolicModel class for interacting with the Hyperbolic AI language models API. + + Attributes: + api_key (str): API key for authenticating requests to the Hyperbolic API. + allowed_models (List[str]): List of allowed model names that can be used. + name (str): The default model name to use for predictions. + type (Literal["HyperbolicModel"]): The type identifier for this class. + + Link to Allowed Models: https://app.hyperbolic.xyz/models + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + api_key: str + allowed_models: List[str] = [ + "Qwen/Qwen2.5-Coder-32B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct", + "Qwen/Qwen2.5-72B-Instruct", + "deepseek-ai/DeepSeek-V2.5", + "meta-llama/Meta-Llama-3-70B-Instruct", + "NousResearch/Hermes-3-Llama-3.1-70B", + "meta-llama/Meta-Llama-3.1-70B-Instruct", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + ] + name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct" + type: Literal["HyperbolicModel"] = "HyperbolicModel" + _BASE_URL: str = PrivateAttr( + default="https://api.hyperbolic.xyz/v1/chat/completions" + ) + _headers: Dict[str, str] = PrivateAttr(default=None) + + def __init__(self, **data) -> None: + """ + Initialize the HyperbolicModel class with the provided data. + + Args: + **data: Arbitrary keyword arguments containing initialization data. 
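+
+        Example (illustrative sketch; assumes a valid Hyperbolic API key):
+            >>> llm = HyperbolicModel(api_key="...")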
+ """ + super().__init__(**data) + self._headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + + def _format_messages( + self, + messages: List[SubclassUnion[MessageBase]], + ) -> List[Dict[str, Any]]: + """ + Formats conversation messages into the structure expected by the API. + + Args: + messages (List[MessageBase]): List of message objects from the conversation history. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. + """ + formatted_messages = [] + for message in messages: + formatted_message = message.model_dump( + include=["content", "role", "name"], exclude_none=True + ) + + if isinstance(formatted_message["content"], list): + formatted_message["content"] = [ + {"type": item["type"], **item} + for item in formatted_message["content"] + ] + + formatted_messages.append(formatted_message) + return formatted_messages + + def _prepare_usage_data( + self, + usage_data, + prompt_time: float = 0.0, + completion_time: float = 0.0, + ) -> UsageData: + """ + Prepare usage data by combining token counts and timing information. + + Args: + usage_data: Raw usage data containing token counts. + prompt_time (float): Time taken for prompt processing. + completion_time (float): Time taken for response completion. + + Returns: + UsageData: Processed usage data. + """ + total_time = prompt_time + completion_time + + # Filter usage data for relevant keys + filtered_usage_data = { + key: value + for key, value in usage_data.items() + if key + not in { + "prompt_tokens", + "completion_tokens", + "total_tokens", + "prompt_time", + "completion_time", + "total_time", + } + } + + usage = UsageData( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + prompt_time=prompt_time, + completion_time=completion_time, + total_time=total_time, + **filtered_usage_data, + ) + + return usage + + @retry_on_status_codes((429, 529), max_retries=1) + def predict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (Optional[int]): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + top_k (int): Maximum number of tokens to consider at each step. + enable_json (bool): Whether to format the response as JSON. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. 
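+
+        Example (illustrative sketch; assumes a valid API key and a populated
+        Conversation):
+            >>> llm = HyperbolicModel(api_key="...")
+            >>> conversation = llm.predict(conversation, temperature=0.5)
+            >>> print(conversation.history[-1].content)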
+        """
+        formatted_messages = self._format_messages(conversation.history)
+        payload = {
+            "model": self.name,
+            "messages": formatted_messages,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "stream": False,
+        }
+
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        if stop is not None:
+            payload["stop"] = stop
+
+        with DurationManager() as prompt_timer:
+            with httpx.Client(timeout=30) as client:
+                response = client.post(
+                    self._BASE_URL, headers=self._headers, json=payload
+                )
+                response.raise_for_status()
+
+        response_data = response.json()
+        message_content = response_data["choices"][0]["message"]["content"]
+        usage_data = response_data.get("usage", {})
+
+        usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
+        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        return conversation
+
+    @retry_on_status_codes((429, 529), max_retries=1)
+    async def apredict(
+        self,
+        conversation: Conversation,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        enable_json: bool = False,
+        stop: Optional[List[str]] = None,
+    ) -> Conversation:
+        """
+        Async method to generate a response from the model based on the given conversation.
+
+        Args are same as predict method.
+        """
+        formatted_messages = self._format_messages(conversation.history)
+        payload = {
+            "model": self.name,
+            "messages": formatted_messages,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "stream": False,
+        }
+
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        if stop is not None:
+            payload["stop"] = stop
+
+        with DurationManager() as prompt_timer:
+            async with httpx.AsyncClient(timeout=30) as client:
+                response = await client.post(
+                    self._BASE_URL, headers=self._headers, json=payload
+                )
+                response.raise_for_status()
+
+        response_data = response.json()
+        message_content = response_data["choices"][0]["message"]["content"]
+        usage_data = response_data.get("usage", {})
+
+        usage = self._prepare_usage_data(usage_data, prompt_timer.duration)
+        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+        return conversation
+
+    @retry_on_status_codes((429, 529), max_retries=1)
+    def stream(
+        self,
+        conversation: Conversation,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        enable_json: bool = False,
+        stop: Optional[List[str]] = None,
+    ) -> Generator[str, None, None]:
+        """
+        Streams response text from the model in real-time.
+
+        Args are same as predict method.
+
+        Yields:
+            str: Partial response content from the model.
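+
+        Example (illustrative sketch; `llm` and `conversation` as in predict):
+            >>> for token in llm.stream(conversation):
+            ...     print(token, end="")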
+        """
+        formatted_messages = self._format_messages(conversation.history)
+        payload = {
+            "model": self.name,
+            "messages": formatted_messages,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "stream": True,
+        }
+
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        if stop is not None:
+            payload["stop"] = stop
+
+        with DurationManager() as prompt_timer:
+            with httpx.Client(timeout=30) as client:
+                response = client.post(
+                    self._BASE_URL, headers=self._headers, json=payload
+                )
+                response.raise_for_status()
+
+        message_content = ""
+        usage_data = {}
+        with DurationManager() as completion_timer:
+            for line in response.iter_lines():
+                json_str = line.replace("data: ", "")
+                try:
+                    if json_str:
+                        chunk = json.loads(json_str)
+                        if chunk["choices"] and chunk["choices"][0]["delta"]:
+                            # Use .get() so role-only delta chunks do not raise KeyError
+                            delta = chunk["choices"][0]["delta"].get("content", "")
+                            message_content += delta
+                            yield delta
+                        if "usage" in chunk and chunk["usage"] is not None:
+                            usage_data = chunk["usage"]
+                except json.JSONDecodeError:
+                    pass
+
+        usage = self._prepare_usage_data(
+            usage_data, prompt_timer.duration, completion_timer.duration
+        )
+        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+
+    @retry_on_status_codes((429, 529), max_retries=1)
+    async def astream(
+        self,
+        conversation: Conversation,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        enable_json: bool = False,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Async generator that streams response text from the model in real-time.
+
+        Args are same as predict method.
+
+        Yields:
+            str: Partial response content from the model.
+        """
+        formatted_messages = self._format_messages(conversation.history)
+        payload = {
+            "model": self.name,
+            "messages": formatted_messages,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "stream": True,
+        }
+
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        if stop is not None:
+            payload["stop"] = stop
+
+        with DurationManager() as prompt_timer:
+            async with httpx.AsyncClient(timeout=30) as client:
+                response = await client.post(
+                    self._BASE_URL, headers=self._headers, json=payload
+                )
+                response.raise_for_status()
+
+        message_content = ""
+        usage_data = {}
+        with DurationManager() as completion_timer:
+            async for line in response.aiter_lines():
+                json_str = line.replace("data: ", "")
+                try:
+                    if json_str:
+                        chunk = json.loads(json_str)
+                        if chunk["choices"] and chunk["choices"][0]["delta"]:
+                            # Use .get() so role-only delta chunks do not raise KeyError
+                            delta = chunk["choices"][0]["delta"].get("content", "")
+                            message_content += delta
+                            yield delta
+                        if "usage" in chunk and chunk["usage"] is not None:
+                            usage_data = chunk["usage"]
+                except json.JSONDecodeError:
+                    pass
+
+        usage = self._prepare_usage_data(
+            usage_data, prompt_timer.duration, completion_timer.duration
+        )
+        conversation.add_message(AgentMessage(content=message_content, usage=usage))
+
+    def batch(
+        self,
+        conversations: List[Conversation],
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        enable_json: bool = False,
+        stop: Optional[List[str]] = None,
+    ) -> List[Conversation]:
+        """
+        Processes a batch of conversations and generates responses for each sequentially.
+
+        Args are same as predict method.
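+
+        Example (illustrative sketch; each item is a populated Conversation):
+            >>> updated = llm.batch([conv_a, conv_b], max_tokens=256)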
+ """ + results = [] + for conversation in conversations: + result_conversation = self.predict( + conversation, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + enable_json=enable_json, + stop=stop, + ) + results.append(result_conversation) + return results + + async def abatch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + top_k: int = -1, + enable_json: bool = False, + stop: Optional[List[str]] = None, + max_concurrent=5, + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args are same as predict method, with additional arg: + max_concurrent (int): Maximum number of concurrent requests. + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(conv: Conversation) -> Conversation: + async with semaphore: + return await self.apredict( + conv, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + enable_json=enable_json, + stop=stop, + ) + + tasks = [process_conversation(conv) for conv in conversations] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py new file mode 100644 index 000000000..14e2d196a --- /dev/null +++ b/pkgs/swarmauri/swarmauri/llms/concrete/HyperbolicVisionModel.py @@ -0,0 +1,381 @@ +import json +from pydantic import PrivateAttr +import httpx +from typing import List, Optional, Dict, Literal, Any, AsyncGenerator, Generator +import asyncio + +from swarmauri_core.typing import SubclassUnion +from swarmauri.conversations.concrete.Conversation import Conversation +from swarmauri.messages.base.MessageBase import MessageBase +from swarmauri.messages.concrete.AgentMessage import AgentMessage +from swarmauri.llms.base.LLMBase import LLMBase +from swarmauri.messages.concrete.AgentMessage import UsageData +from swarmauri.utils.retry_decorator import retry_on_status_codes +from swarmauri.utils.file_path_to_base64 import file_path_to_base64 + + +class HyperbolicVisionModel(LLMBase): + """ + HyperbolicVisionModel class for interacting with the Hyperbolic vision language models API. This class + provides synchronous and asynchronous methods to send conversation data to the + model, receive predictions, and stream responses. + + Attributes: + api_key (str): API key for authenticating requests to the Hyperbolic API. + allowed_models (List[str]): List of allowed model names that can be used. + name (str): The default model name to use for predictions. + type (Literal["HyperbolicVisionModel"]): The type identifier for this class. + + Link to Allowed Models: https://app.hyperbolic.xyz/models + Link to API KEYS: https://app.hyperbolic.xyz/settings + """ + + api_key: str + allowed_models: List[str] = [ + "Qwen/Qwen2-VL-72B-Instruct", + "mistralai/Pixtral-12B-2409", + "Qwen/Qwen2-VL-7B-Instruct", + ] + name: str = "Qwen/Qwen2-VL-72B-Instruct" + type: Literal["HyperbolicVisionModel"] = "HyperbolicVisionModel" + _headers: Dict[str, str] = PrivateAttr(default=None) + _client: httpx.Client = PrivateAttr(default=None) + _BASE_URL: str = PrivateAttr( + default="https://api.hyperbolic.xyz/v1/chat/completions" + ) + + def __init__(self, **data): + """ + Initialize the HyperbolicVisionModel class with the provided data. + + Args: + **data: Arbitrary keyword arguments containing initialization data. 
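+
+        Example (illustrative sketch; assumes a valid Hyperbolic API key):
+            >>> vlm = HyperbolicVisionModel(api_key="...")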
+ """ + super().__init__(**data) + self._headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + self._client = httpx.Client( + headers=self._headers, + base_url=self._BASE_URL, + ) + + def _format_messages( + self, + messages: List[SubclassUnion[MessageBase]], + ) -> List[Dict[str, Any]]: + """ + Formats conversation messages into the structure expected by the API. + + Args: + messages (List[MessageBase]): List of message objects from the conversation history. + + Returns: + List[Dict[str, Any]]: List of formatted message dictionaries. + """ + formatted_messages = [] + for message in messages: + formatted_message = message.model_dump( + include=["content", "role", "name"], exclude_none=True + ) + + if isinstance(formatted_message["content"], list): + formatted_content = [] + for item in formatted_message["content"]: + if item["type"] == "image_url" and "file_path" in item: + # Convert file path to base64 + base64_img = file_path_to_base64(item["file_path"]) + formatted_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_img}" + }, + } + ) + else: + formatted_content.append(item) + formatted_message["content"] = formatted_content + + formatted_messages.append(formatted_message) + return formatted_messages + + def _prepare_usage_data(self, usage_data) -> UsageData: + """ + Prepares and validates usage data received from the API response. + + Args: + usage_data (dict): Raw usage data from the API response. + + Returns: + UsageData: Validated usage data instance. + """ + return UsageData.model_validate(usage_data) + + @retry_on_status_codes((429, 529), max_retries=1) + def predict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Generates a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop or [], + } + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + async def apredict( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Conversation: + """ + Async method to generate a response from the model based on the given conversation. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. 
+ max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + Conversation: Updated conversation with the model's response. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop or [], + } + + async with httpx.AsyncClient() as async_client: + response = await async_client.post( + self._BASE_URL, json=payload, headers=self._headers + ) + response.raise_for_status() + + response_data = response.json() + + message_content = response_data["choices"][0]["message"]["content"] + usage_data = response_data.get("usage", {}) + + usage = self._prepare_usage_data(usage_data) + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + @retry_on_status_codes((429, 529), max_retries=1) + def stream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> Generator[str, None, None]: + """ + Streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. + """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + "stop": stop or [], + } + + response = self._client.post(self._BASE_URL, json=payload) + response.raise_for_status() + + message_content = "" + for line in response.iter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) + + @retry_on_status_codes((429, 529), max_retries=1) + async def astream( + self, + conversation: Conversation, + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> AsyncGenerator[str, None]: + """ + Async generator that streams response text from the model in real-time. + + Args: + conversation (Conversation): Conversation object with message history. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for the model's response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Yields: + str: Partial response content from the model. 
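+
+        Note: once streaming completes, the accumulated text is appended to the
+        conversation as an AgentMessage.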
+ """ + formatted_messages = self._format_messages(conversation.history) + payload = { + "model": self.name, + "messages": formatted_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + "stop": stop or [], + } + + async with httpx.AsyncClient as async_client: + response = await async_client.post( + self._BASE_URL, json=payload, headers=self._headers + ) + response.raise_for_status() + + message_content = "" + async for line in response.aiter_lines(): + json_str = line.replace("data: ", "") + try: + if json_str: + chunk = json.loads(json_str) + if chunk["choices"][0]["delta"]: + delta = chunk["choices"][0]["delta"]["content"] + message_content += delta + yield delta + except json.JSONDecodeError: + pass + + conversation.add_message(AgentMessage(content=message_content)) + + def batch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + ) -> List[Conversation]: + """ + Processes a batch of conversations and generates responses for each sequentially. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + + Returns: + List[Conversation]: List of updated conversations with model responses. + """ + results = [] + for conversation in conversations: + result_conversation = self.predict( + conversation, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop=stop, + ) + results.append(result_conversation) + return results + + async def abatch( + self, + conversations: List[Conversation], + temperature: float = 0.7, + max_tokens: int = 2048, + top_p: float = 0.9, + stop: Optional[List[str]] = None, + max_concurrent=5, + ) -> List[Conversation]: + """ + Async method for processing a batch of conversations concurrently. + + Args: + conversations (List[Conversation]): List of conversations to process. + temperature (float): Sampling temperature for response diversity. + max_tokens (int): Maximum tokens for each response. + top_p (float): Cumulative probability for nucleus sampling. + stop (Optional[List[str]]): List of stop sequences for response termination. + max_concurrent (int): Maximum number of concurrent requests. + + Returns: + List[Conversation]: List of updated conversations with model responses. 
+ """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_conversation(conv: Conversation) -> Conversation: + async with semaphore: + return await self.apredict( + conv, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop=stop, + ) + + tasks = [process_conversation(conv) for conv in conversations] + return await asyncio.gather(*tasks) diff --git a/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py b/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py index a24e7b59f..975ac7e93 100644 --- a/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/llms/concrete/__init__.py @@ -1,47 +1,43 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_description or module_name}' is not available. " - f"Please install the necessary dependencies to enable this functionality.") - return None - -# List of model names (file names without the ".py" extension) -model_files = [ - "AI21StudioModel", - "AnthropicModel", - "AnthropicToolModel", - "BlackForestimgGenModel", - "CohereModel", - "CohereToolModel", - "DeepInfraImgGenModel", - "DeepInfraModel", - "DeepSeekModel", - "FalAllImgGenModel", - "FalAVisionModel", - "GeminiProModel", - "GeminiToolModel", - "GroqAudio", - "GroqModel", - "GroqToolModel", - "GroqVisionModel", - "MistralModel", - "MistralToolModel", - "OpenAIGenModel", - "OpenAIModel", - "OpenAIToolModel", - "PerplexityModel", - "PlayHTModel", - "WhisperLargeModel", +# List of llms names (file names without the ".py" extension) and corresponding class names +llms_files = [ + ("swarmauri.llms.concrete.AI21StudioModel", "AI21StudioModel"), + ("swarmauri.llms.concrete.AnthropicModel", "AnthropicModel"), + ("swarmauri.llms.concrete.AnthropicToolModel", "AnthropicToolModel"), + ("swarmauri.llms.concrete.BlackForestImgGenModel", "BlackForestImgGenModel"), + ("swarmauri.llms.concrete.CohereModel", "CohereModel"), + ("swarmauri.llms.concrete.CohereToolModel", "CohereToolModel"), + ("swarmauri.llms.concrete.DeepInfraImgGenModel", "DeepInfraImgGenModel"), + ("swarmauri.llms.concrete.DeepInfraModel", "DeepInfraModel"), + ("swarmauri.llms.concrete.DeepSeekModel", "DeepSeekModel"), + ("swarmauri.llms.concrete.FalAIImgGenModel", "FalaiImgGenModel"), + ("swarmauri.llms.concrete.FalAIVisionModel", "FalAIVisionModel"), + ("swarmauri.llms.concrete.GeminiProModel", "GeminiProModel"), + ("swarmauri.llms.concrete.GeminiToolModel", "GeminiToolModel"), + ("swarmauri.llms.concrete.GroqAIAudio", "GroqAIAudio"), + ("swarmauri.llms.concrete.GroqModel", "GroqModel"), + ("swarmauri.llms.concrete.GroqToolModel", "GroqToolModel"), + ("swarmauri.llms.concrete.GroqVisionModel", "GroqVisionModel"), + ("swarmauri.llms.concrete.HyperbolicAudioTTS", "HyperbolicAudioTTS"), + ("swarmauri.llms.concrete.HyperbolicImgGenModel", "HyperbolicImgGenModel"), + ("swarmauri.llms.concrete.HyperbolicModel", "HyperbolicModel"), + ("swarmauri.llms.concrete.HyperbolicVisionModel", "HyperbolicVisionModel"), + ("swarmauri.llms.concrete.MistralModel", "MistralModel"), + ("swarmauri.llms.concrete.MistralToolModel", "MistralToolModel"), + ("swarmauri.llms.concrete.OpenAIAudio", "OpenAIAudio"), + ("swarmauri.llms.concrete.OpenAIAudioTTS", "OpenAIAudioTTS"), + 
("swarmauri.llms.concrete.OpenAIImgGenModel", "OpenAIImgGenModel"), + ("swarmauri.llms.concrete.OpenAIModel", "OpenAIModel"), + ("swarmauri.llms.concrete.OpenAIToolModel", "OpenAIToolModel"), + ("swarmauri.llms.concrete.PerplexityModel", "PerplexityModel"), + ("swarmauri.llms.concrete.PlayHTModel", "PlayHTModel"), + ("swarmauri.llms.concrete.WhisperLargeModel", "WhisperLargeModel"), ] -# Lazy loading of models, storing them in variables -for model in model_files: - globals()[model] = _lazy_import(f"swarmauri.llms.concrete.{model}", model) +# Lazy loading of llms classes, storing them in variables +for module_name, class_name in llms_files: + globals()[class_name] = _lazy_import(module_name, class_name) -# Adding the lazy-loaded models to __all__ -__all__ = model_files +# Adding the lazy-loaded llms classes to __all__ +__all__ = [class_name for _, class_name in llms_files] diff --git a/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py b/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py index e340b2b85..ea47cb17f 100644 --- a/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/measurements/concrete/__init__.py @@ -1,6 +1,41 @@ -from swarmauri.measurements.concrete.FirstImpressionMeasurement import FirstImpressionMeasurement -from swarmauri.measurements.concrete.MeanMeasurement import MeanMeasurement -from swarmauri.measurements.concrete.PatternMatchingMeasurement import PatternMatchingMeasurement -from swarmauri.measurements.concrete.RatioOfSumsMeasurement import RatioOfSumsMeasurement -from swarmauri.measurements.concrete.StaticMeasurement import StaticMeasurement -from swarmauri.measurements.concrete.ZeroMeasurement import ZeroMeasurement +from swarmauri.utils._lazy_import import _lazy_import + +# List of measurements names (file names without the ".py" extension) and corresponding class names +measurements_files = [ + ( + "swarmauri.measurements.concrete.CompletenessMeasurement", + "CompletenessMeasurement", + ), + ( + "swarmauri.measurements.concrete.DistinctivenessMeasurement", + "DistinctivenessMeasurement", + ), + ( + "swarmauri.measurements.concrete.FirstImpressionMeasurement", + "FirstImpressionMeasurement", + ), + ("swarmauri.measurements.concrete.MeanMeasurement", "MeanMeasurement"), + ("swarmauri.measurements.concrete.MiscMeasurement", "MiscMeasurement"), + ( + "swarmauri.measurements.concrete.MissingnessMeasurement", + "MissingnessMeasurement", + ), + ( + "swarmauri.measurements.concrete.PatternMatchingMeasurement", + "PatternMatchingMeasurement", + ), + ( + "swarmauri.measurements.concrete.RatioOfSumsMeasurement", + "RatioOfSumsMeasurement", + ), + ("swarmauri.measurements.concrete.StaticMeasurement", "StaticMeasurement"), + ("swarmauri.measurements.concrete.UniquenessMeasurement", "UniquenessMeasurement"), + ("swarmauri.measurements.concrete.ZeroMeasurement", "ZeroMeasurement"), +] + +# Lazy loading of measurements classes, storing them in variables +for module_name, class_name in measurements_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded measurements classes to __all__ +__all__ = [class_name for _, class_name in measurements_files] diff --git a/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py b/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py index 5c619ecc8..716bd57c5 100644 --- a/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/messages/concrete/__init__.py @@ -1,4 +1,16 @@ -from 
swarmauri.messages.concrete.HumanMessage import HumanMessage
-from swarmauri.messages.concrete.AgentMessage import AgentMessage
-from swarmauri.messages.concrete.FunctionMessage import FunctionMessage
-from swarmauri.messages.concrete.SystemMessage import SystemMessage
+from swarmauri.utils._lazy_import import _lazy_import
+
+# List of messages names (file names without the ".py" extension) and corresponding class names
+messages_files = [
+    ("swarmauri.messages.concrete.HumanMessage", "HumanMessage"),
+    ("swarmauri.messages.concrete.AgentMessage", "AgentMessage"),
+    ("swarmauri.messages.concrete.FunctionMessage", "FunctionMessage"),
+    ("swarmauri.messages.concrete.SystemMessage", "SystemMessage"),
+]
+
+# Lazy loading of messages classes, storing them in variables
+for module_name, class_name in messages_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
+
+# Adding the lazy-loaded messages classes to __all__
+__all__ = [class_name for _, class_name in messages_files]
diff --git a/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py b/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py
index 45b1c7640..fb730763f 100644
--- a/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py
+++ b/pkgs/swarmauri/swarmauri/parsers/concrete/__init__.py
@@ -1,37 +1,29 @@
-import importlib
+from swarmauri.utils._lazy_import import _lazy_import
 
-# Define a lazy loader function with a warning message if the module is not found
-def _lazy_import(module_name, module_description=None):
-    try:
-        return importlib.import_module(module_name)
-    except ImportError:
-        # If module is not available, print a warning message
-        print(f"Warning: The module '{module_description or module_name}' is not available. "
-              f"Please install the necessary dependencies to enable this functionality.")
-        return None
-
-# List of parser names (file names without the ".py" extension)
-parser_files = [
-    "BeautifulSoupElementParser",
-    "BERTEmbeddingParser",
-    "CSVParser",
-    "EntityRecognitionParser",
-    "HTMLTagStripParser",
-    "KeywordExtractorParser",
-    "Md2HtmlParser",
-    "OpenAPISpecParser",
-    "PhoneNumberExtractorParser",
-    "PythonParser",
-    "RegExParser",
-    "TextBlobNounParser",
-    "TextBlobSentenceParser",
-    "URLExtractorParser",
-    "XMLParser",
+# List of parsers names (file names without the ".py" extension) and corresponding class names
+parsers_files = [
+    (
+        "swarmauri.parsers.concrete.BeautifulSoupElementParser",
+        "BeautifulSoupElementParser",
+    ),
+    ("swarmauri.parsers.concrete.CSVParser", "CSVParser"),
+    ("swarmauri.parsers.concrete.HTMLTagStripParser", "HTMLTagStripParser"),
+    ("swarmauri.parsers.concrete.KeywordExtractorParser", "KeywordExtractorParser"),
+    ("swarmauri.parsers.concrete.Md2HtmlParser", "Md2HtmlParser"),
+    ("swarmauri.parsers.concrete.OpenAPISpecParser", "OpenAPISpecParser"),
+    (
+        "swarmauri.parsers.concrete.PhoneNumberExtractorParser",
+        "PhoneNumberExtractorParser",
+    ),
+    ("swarmauri.parsers.concrete.PythonParser", "PythonParser"),
+    ("swarmauri.parsers.concrete.RegExParser", "RegExParser"),
+    ("swarmauri.parsers.concrete.URLExtractorParser", "URLExtractorParser"),
+    ("swarmauri.parsers.concrete.XMLParser", "XMLParser"),
 ]
 
-# Lazy loading of parser modules, storing them in variables
-for parser in parser_files:
-    globals()[parser] = _lazy_import(f"swarmauri.parsers.concrete.{parser}", parser)
+# Lazy loading of parsers classes, storing them in variables
+for module_name, class_name in parsers_files:
+    globals()[class_name] = _lazy_import(module_name, class_name)
 
-# Adding the
lazy-loaded parser modules to __all__ -__all__ = parser_files +# Adding the lazy-loaded parsers classes to __all__ +__all__ = [class_name for _, class_name in parsers_files] diff --git a/pkgs/swarmauri/swarmauri/pipelines/base/PipelineBase.py b/pkgs/swarmauri/swarmauri/pipelines/base/PipelineBase.py new file mode 100644 index 000000000..6cc302c2b --- /dev/null +++ b/pkgs/swarmauri/swarmauri/pipelines/base/PipelineBase.py @@ -0,0 +1,105 @@ +from typing import Any, Callable, List, Optional, Dict +from pydantic import BaseModel, ConfigDict, Field +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes +from swarmauri_core.pipelines.IPipeline import IPipeline, PipelineStatus +import uuid + + +class PipelineBase(IPipeline, ComponentBase): + """ + Base class providing default behavior for task orchestration, + error handling, and result aggregation. + """ + + resource: Optional[str] = Field(default=ResourceTypes.PIPELINE.value, frozen=True) + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + type: str = "PipelineBase" + + # Pydantic model fields + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + tasks: List[Dict[str, Any]] = Field(default_factory=list) + parallel: bool = Field(default=False) + + def __init__( + self, tasks: Optional[List[Dict[str, Any]]] = None, parallel: bool = False + ): + """ + Initialize the pipeline. + + :param tasks: Optional list of tasks to initialize pipeline with + :param parallel: Flag to indicate parallel or sequential execution + """ + super().__init__() + self.tasks = tasks or [] + self._results: List[Any] = [] + self._status: PipelineStatus = PipelineStatus.PENDING + self.parallel = parallel + + def add_task(self, task: Callable, *args: Any, **kwargs: Any) -> None: + """ + Add a task to the pipeline. + + :param task: Callable task to be executed + :param args: Positional arguments for the task + :param kwargs: Keyword arguments for the task + """ + task_entry = {"callable": task, "args": args, "kwargs": kwargs} + self.tasks.append(task_entry) + + def execute(self, *args: Any, **kwargs: Any) -> List[Any]: + """ + Execute pipeline tasks. + + :return: List of results from pipeline execution + """ + try: + self._status = PipelineStatus.RUNNING + self._results = [] + + if self.parallel: + # Implement parallel execution logic + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + task["callable"], *task["args"], **task["kwargs"] + ) + for task in self.tasks + ] + self._results = [future.result() for future in futures] + else: + # Sequential execution + for task in self.tasks: + result = task["callable"](*task["args"], **task["kwargs"]) + self._results.append(result) + + self._status = PipelineStatus.COMPLETED + return self._results + + except Exception as e: + self._status = PipelineStatus.FAILED + raise RuntimeError(f"Pipeline execution failed: {e}") + + def get_status(self) -> PipelineStatus: + """ + Get the current status of the pipeline. + + :return: Current pipeline status + """ + return self._status + + def reset(self) -> None: + """ + Reset the pipeline to its initial state. + """ + self._results = [] + self._status = PipelineStatus.PENDING + + def get_results(self) -> List[Any]: + """ + Get the results of the pipeline execution. 
+ + :return: List of results + """ + return self._results diff --git a/pkgs/swarmauri/swarmauri/pipelines/base/__init__.py b/pkgs/swarmauri/swarmauri/pipelines/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/pipelines/concrete/Pipeline.py b/pkgs/swarmauri/swarmauri/pipelines/concrete/Pipeline.py new file mode 100644 index 000000000..82c4ec593 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/pipelines/concrete/Pipeline.py @@ -0,0 +1,50 @@ +from typing import Any, Callable, List, Optional, Dict +from swarmauri.pipelines.base.PipelineBase import PipelineBase + + +class Pipeline(PipelineBase): + """ + Concrete implementation of a pipeline with additional + customization options. + """ + + type: str = "Pipeline" + + def __init__( + self, + tasks: Optional[List[Dict[str, Any]]] = None, + parallel: bool = False, + error_handler: Optional[Callable[[Exception], Any]] = None, + ): + """ + Initialize a customizable pipeline. + + :param tasks: Optional list of tasks to initialize pipeline with + :param parallel: Flag to indicate parallel or sequential execution + :param error_handler: Optional custom error handling function + """ + super().__init__(tasks, parallel) + self._error_handler = error_handler + + def execute(self, *args: Any, **kwargs: Any) -> List[Any]: + """ + Execute pipeline with optional custom error handling. + + :return: List of results from pipeline execution + """ + try: + return super().execute(*args, **kwargs) + except Exception as e: + if self._error_handler: + return [self._error_handler(e)] + raise + + def with_error_handler(self, handler: Callable[[Exception], Any]) -> "Pipeline": + """ + Add a custom error handler to the pipeline. + + :param handler: Error handling function + :return: Current pipeline instance + """ + self._error_handler = handler + return self diff --git a/pkgs/swarmauri/swarmauri/pipelines/concrete/__init__.py b/pkgs/swarmauri/swarmauri/pipelines/concrete/__init__.py new file mode 100644 index 000000000..db1142f92 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/pipelines/concrete/__init__.py @@ -0,0 +1,13 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of pipeline names (file names without the ".py" extension) and corresponding class names +pipeline_files = [ + ("swarmauri.pipelines.concrete.Pipeline", "Pipeline"), +] + +# Lazy loading of pipeline classes, storing them in variables +for module_name, class_name in pipeline_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded pipeline classes to __all__ +__all__ = [class_name for _, class_name in pipeline_files] diff --git a/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py b/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py index 00d6b3cb9..3755b609f 100644 --- a/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/prompts/concrete/__init__.py @@ -1,4 +1,16 @@ -from swarmauri.prompts.concrete.Prompt import Prompt -from swarmauri.prompts.concrete.PromptGenerator import PromptGenerator -from swarmauri.prompts.concrete.PromptMatrix import PromptMatrix -from swarmauri.prompts.concrete.PromptTemplate import PromptTemplate +from swarmauri.utils._lazy_import import _lazy_import + +# List of prompts names (file names without the ".py" extension) and corresponding class names +prompts_files = [ + ("swarmauri.prompts.concrete.Prompt", "Prompt"), + ("swarmauri.prompts.concrete.PromptGenerator", "PromptGenerator"), + ("swarmauri.prompts.concrete.PromptMatrix", 
"PromptMatrix"), + ("from swarmauri.prompts.concrete.PromptTemplate", "PromptTemplate"), +] + +# Lazy loading of prompts classes, storing them in variables +for module_name, class_name in prompts_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded prompts classes to __all__ +__all__ = [class_name for _, class_name in prompts_files] diff --git a/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py b/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py index c608d8c11..65044d64d 100644 --- a/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/schema_converters/concrete/__init__.py @@ -1,29 +1,37 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_description or module_name}' is not available. " - f"Please install the necessary dependencies to enable this functionality.") - return None - -# List of schema converter names (file names without the ".py" extension) -schema_converter_files = [ - "AnthropicSchemaConverter", - "CohereSchemaConverter", - "GeminiSchemaConverter", - "GroqSchemaConverter", - "MistralSchemaConverter", - "OpenAISchemaConverter", - "ShuttleAISchemaConverter", +# List of schema_converters names (file names without the ".py" extension) and corresponding class names +schema_converters_files = [ + ( + "swarmauri.schema_converters.concrete.AnthropicSchemaConverter", + "AnthropicSchemaConverter", + ), + ( + "swarmauri.schema_converters.concrete.CohereSchemaConverter", + "CohereSchemaConverter", + ), + ( + "swarmauri.schema_converters.concrete.GeminiSchemaConverter", + "GeminiSchemaConverter", + ), + ("swarmauri.schema_converters.concrete.GroqSchemaConverter", "GroqSchemaConverter"), + ( + "swarmauri.schema_converters.concrete.MistralSchemaConverter", + "MistralSchemaConverter", + ), + ( + "swarmauri.schema_converters.concrete.OpenAISchemaConverter", + "OpenAISchemaConverter", + ), + ( + "swarmauri.schema_converters.concrete.ShuttleAISchemaConverter", + "ShuttleAISchemaConverter", + ), ] -# Lazy loading of schema converters, storing them in variables -for schema_converter in schema_converter_files: - globals()[schema_converter] = _lazy_import(f"swarmauri.schema_converters.concrete.{schema_converter}", schema_converter) +# Lazy loading of schema_converters classes, storing them in variables +for module_name, class_name in schema_converters_files: + globals()[class_name] = _lazy_import(module_name, class_name) -# Adding the lazy-loaded schema converters to __all__ -__all__ = schema_converter_files +# Adding the lazy-loaded schema_converters classes to __all__ +__all__ = [class_name for _, class_name in schema_converters_files] diff --git a/pkgs/swarmauri/swarmauri/service_registries/__init__.py b/pkgs/swarmauri/swarmauri/service_registries/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/service_registries/base/ServiceRegistryBase.py b/pkgs/swarmauri/swarmauri/service_registries/base/ServiceRegistryBase.py new file mode 100644 index 000000000..a7e567a88 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/service_registries/base/ServiceRegistryBase.py @@ -0,0 +1,60 @@ +from typing import Dict, Any, List, 
+
+from pydantic import ConfigDict, Field
+from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
+from swarmauri_core.service_registries.IServiceRegistry import IServiceRegistry
+
+
+class ServiceRegistryBase(IServiceRegistry, ComponentBase):
+    """
+    Base implementation of the IServiceRegistry interface.
+    """
+
+    services: Dict[str, Any] = {}
+    type: Literal["ServiceRegistryBase"] = "ServiceRegistryBase"
+    resource: Optional[str] = Field(
+        default=ResourceTypes.SERVICE_REGISTRY.value, frozen=True
+    )
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+    def register_service(self, name: str, details: Dict[str, Any]) -> None:
+        """
+        Register a new service with the given name and details.
+        """
+        self.services[name] = details
+
+    def get_service(self, name: str) -> Optional[Dict[str, Any]]:
+        """
+        Retrieve a service by its name.
+        """
+        return self.services.get(name)
+
+    def get_services_by_roles(self, roles: List[str]) -> List[str]:
+        """
+        Get services filtered by their roles.
+        """
+        return [
+            name
+            for name, details in self.services.items()
+            if details.get("role") in roles
+        ]
+
+    def unregister_service(self, name: str) -> None:
+        """
+        Unregister the service with the given name.
+        """
+        if name in self.services:
+            del self.services[name]
+            print(f"Service {name} unregistered.")
+        else:
+            raise ValueError(f"Service {name} not found.")
+
+    def update_service(self, name: str, details: Dict[str, Any]) -> None:
+        """
+        Update the details of the service with the given name.
+        """
+        if name in self.services:
+            self.services[name].update(details)
+            print(f"Service {name} updated with new details: {details}")
+        else:
+            raise ValueError(f"Service {name} not found.")
diff --git a/pkgs/swarmauri/swarmauri/service_registries/base/__init__.py b/pkgs/swarmauri/swarmauri/service_registries/base/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pkgs/swarmauri/swarmauri/service_registries/concrete/ServiceRegistry.py b/pkgs/swarmauri/swarmauri/service_registries/concrete/ServiceRegistry.py
new file mode 100644
index 000000000..d61d34452
--- /dev/null
+++ b/pkgs/swarmauri/swarmauri/service_registries/concrete/ServiceRegistry.py
@@ -0,0 +1,10 @@
+from typing import Literal
+from swarmauri.service_registries.base.ServiceRegistryBase import ServiceRegistryBase
+
+
+class ServiceRegistry(ServiceRegistryBase):
+    """
+    Concrete implementation of the ServiceRegistryBase.
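+
+    Example (minimal usage sketch):
+        >>> registry = ServiceRegistry()
+        >>> registry.register_service("billing", {"role": "billing"})
+        >>> registry.get_service("billing")
+        {'role': 'billing'}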
+ """ + + type: Literal["ServiceRegistry"] = "ServiceRegistry" diff --git a/pkgs/swarmauri/swarmauri/service_registries/concrete/__init__.py b/pkgs/swarmauri/swarmauri/service_registries/concrete/__init__.py new file mode 100644 index 000000000..09ec2f608 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/service_registries/concrete/__init__.py @@ -0,0 +1,16 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of service_registry name (file names without the ".py" extension) and corresponding class names +service_registry_files = [ + ( + "swarmauri.service_registries.concrete.ServiceRegistry", + "ServiceRegistry", + ), +] + +# Lazy loading of service_registry classes, storing them in variables +for module_name, class_name in service_registry_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded service_registry classes to __all__ +__all__ = [class_name for _, class_name in service_registry_files] diff --git a/pkgs/swarmauri/swarmauri/state/__init__.py b/pkgs/swarmauri/swarmauri/state/__init__.py new file mode 100644 index 000000000..67f09d1f4 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/state/__init__.py @@ -0,0 +1 @@ +from swarmauri.state.concrete import * diff --git a/pkgs/swarmauri/swarmauri/state/base/StateBase.py b/pkgs/swarmauri/swarmauri/state/base/StateBase.py new file mode 100644 index 000000000..03aa0f544 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/state/base/StateBase.py @@ -0,0 +1,47 @@ +from typing import Dict, Any, Optional, Literal +from pydantic import Field, ConfigDict +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes +from swarmauri_core.state.IState import IState + + +class StateBase(IState, ComponentBase): + """ + Abstract base class for state management, extending IState and ComponentBase. + """ + + state_data: Dict[str, Any] = Field( + default_factory=dict, description="The current state data." + ) + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + resource: Optional[str] = Field(default=ResourceTypes.STATE.value, frozen=True) + type: Literal["StateBase"] = "StateBase" + + def read(self) -> Dict[str, Any]: + """ + Reads and returns the current state as a dictionary. + """ + raise NotImplementedError("Subclasses must implement 'read'.") + + def write(self, data: Dict[str, Any]) -> None: + """ + Replaces the current state with the given data. + """ + raise NotImplementedError("Subclasses must implement 'write'.") + + def update(self, data: Dict[str, Any]) -> None: + """ + Updates the state with the given data. + """ + raise NotImplementedError("Subclasses must implement 'update'.") + + def reset(self) -> None: + """ + Resets the state to its initial state. + """ + raise NotImplementedError("Subclasses must implement 'reset'.") + + def deep_copy(self) -> "IState": + """ + Creates a deep copy of the current state. 
+ """ + raise NotImplementedError("Subclasses must implement 'deep_copy'.") diff --git a/pkgs/swarmauri/swarmauri/state/base/__init__.py b/pkgs/swarmauri/swarmauri/state/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/state/concrete/DictState.py b/pkgs/swarmauri/swarmauri/state/concrete/DictState.py new file mode 100644 index 000000000..1f5c6adee --- /dev/null +++ b/pkgs/swarmauri/swarmauri/state/concrete/DictState.py @@ -0,0 +1,50 @@ +from typing import Dict, Any +from copy import deepcopy +from pydantic import Field, model_validator +from swarmauri.state.base.StateBase import StateBase + + +class DictState(StateBase): + """ + A concrete implementation of StateBase that manages state as a dictionary. + """ + + state_data: Dict[str, Any] = Field( + default_factory=dict, description="The current state data." + ) + + def read(self) -> Dict[str, Any]: + """ + Reads and returns the current state as a dictionary. + """ + return deepcopy(self.state_data) + + def write(self, data: Dict[str, Any]) -> None: + """ + Replaces the current state with the given data. + """ + self.state_data = deepcopy(data) + + def update(self, data: Dict[str, Any]) -> None: + """ + Updates the state with the given data. + """ + self.state_data.update(data) + + def reset(self) -> None: + """ + Resets the state to an empty dictionary. + """ + self.state_data = {} + + def deep_copy(self) -> "DictState": + """ + Creates a deep copy of the current state. + """ + return DictState(state_data=deepcopy(self.state_data)) + + @model_validator(mode="after") + def _ensure_deep_copy(self): + # Ensures that state_data is always a deep copy + self.state_data = deepcopy(self.state_data) + return self diff --git a/pkgs/swarmauri/swarmauri/state/concrete/__init__.py b/pkgs/swarmauri/swarmauri/state/concrete/__init__.py new file mode 100644 index 000000000..3752dc4af --- /dev/null +++ b/pkgs/swarmauri/swarmauri/state/concrete/__init__.py @@ -0,0 +1,13 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of state names (file names without the ".py" extension) and corresponding class names +state_files = [ + ("swarmauri.state.concrete.DictState", "DictState"), +] + +# Lazy loading of state classes, storing them in variables +for module_name, class_name in state_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded state classes to __all__ +__all__ = [class_name for _, class_name in state_files] diff --git a/pkgs/swarmauri/swarmauri/swarms/__init__.py b/pkgs/swarmauri/swarmauri/swarms/__init__.py index 97c140f08..e69de29bb 100644 --- a/pkgs/swarmauri/swarmauri/swarms/__init__.py +++ b/pkgs/swarmauri/swarmauri/swarms/__init__.py @@ -1 +0,0 @@ -from swarmauri.swarms.concrete import * diff --git a/pkgs/swarmauri/swarmauri/swarms/base/SwarmBase.py b/pkgs/swarmauri/swarmauri/swarms/base/SwarmBase.py new file mode 100644 index 000000000..88ba5e968 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/swarms/base/SwarmBase.py @@ -0,0 +1,101 @@ +import asyncio +from typing import Any, Dict, List, Literal, Optional, Union +from pydantic import ConfigDict, Field +from enum import Enum +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes +from swarmauri_core.swarms.ISwarm import ISwarm + + +class SwarmStatus(Enum): + IDLE = "IDLE" + WORKING = "WORKING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +class SwarmBase(ISwarm, ComponentBase): + """Base class for Swarm implementations""" + + model_config = ConfigDict(extra="forbid", 
arbitrary_types_allowed=True) + resource: Optional[str] = Field(default=ResourceTypes.SWARM.value, frozen=True) + type: Literal["SwarmBase"] = "SwarmBase" + + num_agents: int = Field(default=5, gt=0, le=100) + agent_timeout: float = Field(default=1.0, gt=0) + max_retries: int = Field(default=3, ge=0) + max_queue_size: int = Field(default=10, gt=0) + + _agents: List[Any] = [] + _task_queue: Optional[asyncio.Queue] = None + _status: Dict[int, SwarmStatus] = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._task_queue = asyncio.Queue(maxsize=self.max_queue_size) + self._initialize_agents() + + def _initialize_agents(self): + self._agents = [self._create_agent() for _ in range(self.num_agents)] + self._status = {i: SwarmStatus.IDLE for i in range(self.num_agents)} + + def _create_agent(self) -> Any: + """create specific agent types""" + raise NotImplementedError("Agent creation method not implemented") + + @property + def agents(self) -> List[Any]: + return self._agents + + @property + def queue_size(self) -> int: + return self._task_queue.qsize() + + def get_swarm_status(self) -> Dict[int, SwarmStatus]: + return self._status + + async def _process_task(self, agent_id: int, task: Any, **kwargs) -> Any: + self._status[agent_id] = SwarmStatus.WORKING + try: + for _ in range(self.max_retries): + try: + result = await asyncio.wait_for( + self._execute_task(task, agent_id, **kwargs), + timeout=self.agent_timeout, + ) + self._status[agent_id] = SwarmStatus.COMPLETED + return result + except asyncio.TimeoutError: + continue + self._status[agent_id] = SwarmStatus.FAILED + return None + except Exception as e: + self._status[agent_id] = SwarmStatus.FAILED + raise e + + async def _execute_task(self, task: Any, agent_id: int) -> Any: + """Override this method to implement specific task execution logic""" + raise NotImplementedError("Task execution method not implemented") + + async def exec( + self, input_data: Union[List[str], Any] = [], **kwargs: Optional[Dict] + ) -> List[Any]: + tasks = input_data if isinstance(input_data, list) else [input_data] + for task in tasks: + await self._task_queue.put(task) + + results = [] + while not self._task_queue.empty(): + available_agents = [ + i for i, status in self._status.items() if status == SwarmStatus.IDLE + ] + if not available_agents: + await asyncio.sleep(0.1) + continue + + task = await self._task_queue.get() + agent_id = available_agents[0] + result = await self._process_task(agent_id, task, **kwargs) + if result is not None: + results.append(result) + + return results diff --git a/pkgs/swarmauri/swarmauri/swarms/base/SwarmComponentBase.py b/pkgs/swarmauri/swarmauri/swarms/base/SwarmComponentBase.py deleted file mode 100644 index 8e643494a..000000000 --- a/pkgs/swarmauri/swarmauri/swarms/base/SwarmComponentBase.py +++ /dev/null @@ -1,15 +0,0 @@ -from swarmauri_core.swarms.ISwarmComponent import ISwarmComponent - -class SwarmComponentBase(ISwarmComponent): - """ - Interface for defining basics of any component within the swarm system. 
- """ - def __init__(self, key: str, name: str, superclass: str, module: str, class_name: str, args=None, kwargs=None): - self.key = key - self.name = name - self.superclass = superclass - self.module = module - self.class_name = class_name - self.args = args or [] - self.kwargs = kwargs or {} - \ No newline at end of file diff --git a/pkgs/swarmauri/swarmauri/swarms/concrete/SimpleSwarmFactory.py b/pkgs/swarmauri/swarmauri/swarms/concrete/SimpleSwarmFactory.py deleted file mode 100644 index beaec4e49..000000000 --- a/pkgs/swarmauri/swarmauri/swarms/concrete/SimpleSwarmFactory.py +++ /dev/null @@ -1,50 +0,0 @@ -import json -import pickle -from typing import List -from swarmauri_core.chains.ISwarmFactory import ( - ISwarmFactory , - CallableChainItem, - AgentDefinition, - FunctionDefinition -) -class SimpleSwarmFactory(ISwarmFactory): - def __init__(self): - self.swarms = [] - self.callable_chains = [] - - def create_swarm(self, agents=[]): - swarm = {"agents": agents} - self.swarms.append(swarm) - return swarm - - def create_agent(self, agent_definition: AgentDefinition): - # For simplicity, agents are stored in a list - # Real-world usage might involve more sophisticated management and instantiation based on type and configuration - agent = {"definition": agent_definition._asdict()} - self.agents.append(agent) - return agent - - def create_callable_chain(self, chain_definition: List[CallableChainItem]): - chain = {"definition": [item._asdict() for item in chain_definition]} - self.callable_chains.append(chain) - return chain - - def register_function(self, function_definition: FunctionDefinition): - if function_definition.identifier in self.functions: - raise ValueError(f"Function {function_definition.identifier} is already registered.") - - self.functions[function_definition.identifier] = function_definition - - def export_configuration(self, format_type: str = 'json'): - # Now exporting both swarms and callable chains - config = {"swarms": self.swarms, "callable_chains": self.callable_chains} - if format_type == "json": - return json.dumps(config) - elif format_type == "pickle": - return pickle.dumps(config) - - def load_configuration(self, config_data, format_type: str = 'json'): - # Loading both swarms and callable chains - config = json.loads(config_data) if format_type == "json" else pickle.loads(config_data) - self.swarms = config.get("swarms", []) - self.callable_chains = config.get("callable_chains", []) \ No newline at end of file diff --git a/pkgs/swarmauri/swarmauri/swarms/concrete/Swarm.py b/pkgs/swarmauri/swarmauri/swarms/concrete/Swarm.py new file mode 100644 index 000000000..2c4c7fee8 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/swarms/concrete/Swarm.py @@ -0,0 +1,36 @@ +from typing import Any, Dict, List, Literal, Optional, Type, Union +from pydantic import Field + +from swarmauri.swarms.base.SwarmBase import SwarmBase, SwarmStatus + + +class Swarm(SwarmBase): + """Concrete implementation of SwarmBase for task processing""" + + type: Literal["Swarm"] = "Swarm" + agent_class: Type[Any] = Field(description="Agent class to use for swarm") + task_batch_size: int = Field(default=1, gt=0) + + def _create_agent(self) -> Any: + """Create new agent instance""" + return self.agent_class() + + async def _execute_task(self, task: Any, agent_id: int, **kwargs) -> Dict[str, Any]: + """Execute task using specified agent""" + agent = self._agents[agent_id] + try: + result = await agent.process(task, **kwargs) + return { + "agent_id": agent_id, + "status": SwarmStatus.COMPLETED, + 
"result": result, + } + except Exception as e: + return {"agent_id": agent_id, "status": SwarmStatus.FAILED, "error": str(e)} + + async def exec( + self, input_data: Union[List[str], Any] = [], **kwargs: Optional[Dict] + ) -> List[Dict[str, Any]]: + """Execute tasks in parallel using available agents""" + results = await super().exec(input_data, **kwargs) + return results diff --git a/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py b/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py index bd32d1999..98af9eb31 100644 --- a/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/swarms/concrete/__init__.py @@ -1 +1,13 @@ -from swarmauri.swarms.concrete.SimpleSwarmFactory import SimpleSwarmFactory +from swarmauri.utils._lazy_import import _lazy_import + +# List of swarms names (file names without the ".py" extension) and corresponding class names +swarms_files = [ + ("swarmauri.swarms.concrete.Swarm", "Swarm") +] + +# Lazy loading of swarms classes, storing them in variables +for module_name, class_name in swarms_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded swarms classes to __all__ +__all__ = [class_name for _, class_name in swarms_files] diff --git a/pkgs/swarmauri/swarmauri/task_mgt_strategies/__init__.py b/pkgs/swarmauri/swarmauri/task_mgt_strategies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/task_mgt_strategies/base/TaskMgtStrategyBase.py b/pkgs/swarmauri/swarmauri/task_mgt_strategies/base/TaskMgtStrategyBase.py new file mode 100644 index 000000000..698751f71 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/task_mgt_strategies/base/TaskMgtStrategyBase.py @@ -0,0 +1,63 @@ +from abc import abstractmethod + +from pydantic import ConfigDict, Field +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes +from swarmauri_core.task_mgt_strategies.ITaskMgtStrategy import ITaskMgtStrategy +from typing import Any, Callable, Dict, Literal, Optional + + +class TaskMgtStrategyBase(ITaskMgtStrategy, ComponentBase): + """Base class for TaskStrategy.""" + + type: Literal["TaskMgtStrategyBase"] = "TaskMgtStrategyBase" + resource: Optional[str] = Field( + default=ResourceTypes.TASK_MGT_STRATEGY.value, frozen=True + ) + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + @abstractmethod + def assign_task( + self, task: Dict[str, Any], agent_factory: Callable, service_registry: Callable + ) -> str: + """ + Abstract method to assign a task to a service. + """ + raise NotImplementedError( + "assign_task method must be implemented in derived classes." + ) + + @abstractmethod + def add_task(self, task: Dict[str, Any]) -> None: + """ + Abstract method to add a task to the task queue. + """ + raise NotImplementedError( + "add_task method must be implemented in derived classes." + ) + + @abstractmethod + def remove_task(self, task_id: str) -> None: + """ + Abstract method to remove a task from the task queue. + """ + raise NotImplementedError( + "remove_task method must be implemented in derived classes." + ) + + @abstractmethod + def get_task(self, task_id: str) -> Dict[str, Any]: + """ + Abstract method to get a task from the task queue. + """ + raise NotImplementedError( + "get_task method must be implemented in derived classes." + ) + + @abstractmethod + def process_tasks(self, task: Dict[str, Any]) -> None: + """ + Abstract method to process tasks. 
+ """ + raise NotImplementedError( + "process_task method must be implemented in derived classes." + ) diff --git a/pkgs/swarmauri/swarmauri/task_mgt_strategies/base/__init__.py b/pkgs/swarmauri/swarmauri/task_mgt_strategies/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/task_mgt_strategies/concrete/RoundRobinStrategy.py b/pkgs/swarmauri/swarmauri/task_mgt_strategies/concrete/RoundRobinStrategy.py new file mode 100644 index 000000000..d04c5a442 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/task_mgt_strategies/concrete/RoundRobinStrategy.py @@ -0,0 +1,73 @@ +from typing import Callable, Dict, Any, List +from swarmauri.task_mgt_strategies.base.TaskMgtStrategyBase import TaskMgtStrategyBase +from queue import Queue +import logging + + +class RoundRobinStrategy(TaskMgtStrategyBase): + """Round-robin task assignment strategy.""" + + task_queue: Queue = Queue() # Synchronous task queue for incoming tasks + task_assignments: Dict[str, str] = {} # Tracks task assignments + current_index: int = 0 # Tracks the next service to assign tasks to + + def assign_task(self, task: Dict[str, Any], service_registry: Callable[[], List[str]]) -> None: + """ + Assign a task to a service using the round-robin strategy. + :param task: Task metadata and payload. + :param service_registry: Callable that returns available services. + """ + available_services = service_registry() + if not available_services: + raise ValueError("No services available for task assignment.") + + # Select the service based on the round-robin index + service = available_services[self.current_index % len(available_services)] + self.task_assignments[task["task_id"]] = service + self.current_index += 1 + logging.info(f"Task '{task['task_id']}' assigned to service '{service}'.") + + def add_task(self, task: Dict[str, Any]) -> None: + """ + Add a task to the task queue. + :param task: Task metadata and payload. + """ + self.task_queue.put(task) + + def remove_task(self, task_id: str) -> None: + """ + Remove a task from the task registry. + :param task_id: Unique identifier of the task to remove. + """ + if task_id in self.task_assignments: + del self.task_assignments[task_id] + logging.info(f"Task '{task_id}' removed from assignments.") + else: + raise ValueError(f"Task '{task_id}' not found in assignments.") + + def get_task(self, task_id: str) -> Dict[str, Any]: + """ + Get a task's assigned service. + :param task_id: Unique identifier of the task. + :return: Task assignment details. + """ + if task_id in self.task_assignments: + service = self.task_assignments[task_id] + return {"task_id": task_id, "assigned_service": service} + else: + raise ValueError(f"Task '{task_id}' not found in assignments.") + + def process_tasks(self, service_registry: Callable[[], List[str]], transport: Callable) -> None: + """ + Process tasks from the task queue and assign them to services. + :param service_registry: Callable that returns available services. + :param transport: Callable used to send tasks to assigned services. 
+ """ + while not self.task_queue.empty(): + task = self.task_queue.get() + try: + self.assign_task(task, service_registry) + assigned_service = self.task_assignments[task["task_id"]] + transport.send(task, assigned_service) + except ValueError as e: + raise ValueError(f"Error assigning task: {e}") diff --git a/pkgs/swarmauri/swarmauri/task_mgt_strategies/concrete/__init__.py b/pkgs/swarmauri/swarmauri/task_mgt_strategies/concrete/__init__.py new file mode 100644 index 000000000..3ee5bec23 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/task_mgt_strategies/concrete/__init__.py @@ -0,0 +1,13 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of task_mgt_strategies names (file names without the ".py" extension) and corresponding class names +task_mgt_strategies_files = [ + ("swarmauri.task_mgt_strategies.concrete.RoundRobinStrategy", "RoundRobinStrategy"), +] + +# Lazy loading of task_mgt_strategies classes, storing them in variables +for module_name, class_name in task_mgt_strategies_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded state classes to __all__ +__all__ = [class_name for _, class_name in task_mgt_strategies_files] diff --git a/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py b/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py index 87127d6bf..a7311c7c9 100644 --- a/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/toolkits/concrete/__init__.py @@ -1,21 +1,4 @@ -import importlib - -# Define a lazy loader function with a warning message if the module or class is not found -def _lazy_import(module_name, class_name): - try: - # Import the module - module = importlib.import_module(module_name) - # Dynamically get the class from the module - return getattr(module, class_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_name}' is not available. 
" - f"Please install the necessary dependencies to enable this functionality.") - return None - except AttributeError: - # If class is not found, print a warning message - print(f"Warning: The class '{class_name}' was not found in module '{module_name}'.") - return None +from swarmauri.utils._lazy_import import _lazy_import # List of toolkit names (file names without the ".py" extension) and corresponding class names toolkit_files = [ diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/Parameter.py b/pkgs/swarmauri/swarmauri/tools/concrete/Parameter.py index 2f5a357b6..59afe9a83 100644 --- a/pkgs/swarmauri/swarmauri/tools/concrete/Parameter.py +++ b/pkgs/swarmauri/swarmauri/tools/concrete/Parameter.py @@ -1,7 +1,10 @@ -from typing import Literal, Union -from pydantic import Field +from typing import List, Literal, Union from swarmauri.tools.base.ParameterBase import ParameterBase class Parameter(ParameterBase): - type: Union[Literal["string", "number", "boolean", "array", "object"], str] + type: Union[ + Literal["string", "number", "boolean", "array", "object"], + str, + List[str], + ] diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/SMOGIndexTool.py b/pkgs/swarmauri/swarmauri/tools/concrete/SMOGIndexTool.py deleted file mode 100644 index 23ce384df..000000000 --- a/pkgs/swarmauri/swarmauri/tools/concrete/SMOGIndexTool.py +++ /dev/null @@ -1,113 +0,0 @@ -from swarmauri_core.typing import SubclassUnion -from typing import List, Literal, Dict -from pydantic import Field -from swarmauri.tools.base.ToolBase import ToolBase -from swarmauri.tools.concrete.Parameter import Parameter -import re -import math -import nltk -from nltk.tokenize import sent_tokenize - -# Download required NLTK data once during module load -nltk.download("punkt", quiet=True) - - -class SMOGIndexTool(ToolBase): - version: str = "0.1.dev2" - parameters: List[Parameter] = Field( - default_factory=lambda: [ - Parameter( - name="text", - type="string", - description="The text to analyze for SMOG Index", - required=True, - ) - ] - ) - name: str = "SMOGIndexTool" - description: str = "Calculates the SMOG Index for the provided text." - type: Literal["SMOGIndexTool"] = "SMOGIndexTool" - - def __call__(self, text: str) -> Dict[str, float]: - """ - Calculates the SMOG Index for the provided text. - - Parameters: - text (str): The text to analyze. - - Returns: - float: The calculated SMOG Index. - """ - return {"smog_index": self.calculate_smog_index(text)} - - def calculate_smog_index(self, text: str) -> float: - """ - Calculate the SMOG Index for a given text. - - Parameters: - text (str): The text to analyze. - - Returns: - float: The calculated SMOG Index. - """ - sentences = self.count_sentences(text) - polysyllables = self.count_polysyllables(text) - - if sentences == 0: - return 0.0 # Avoid division by zero - - smog_index = 1.0430 * math.sqrt(polysyllables * (30 / sentences)) + 3.1291 - return round(smog_index, 1) - - def count_sentences(self, text: str) -> int: - """ - Count the number of sentences in the text. - - Parameters: - text (str): The text to analyze. - - Returns: - int: The number of sentences in the text. - """ - sentences = sent_tokenize(text) - return len(sentences) - - def count_polysyllables(self, text: str) -> int: - """ - Count the number of polysyllabic words (words with three or more syllables) in the text. - - Parameters: - text (str): The text to analyze. - - Returns: - int: The number of polysyllabic words in the text. 
- """ - words = re.findall(r"\w+", text) - return len([word for word in words if self.count_syllables(word) >= 3]) - - def count_syllables(self, word: str) -> int: - """ - Count the number of syllables in a given word. - - Parameters: - word (str): The word to analyze. - - Returns: - int: The number of syllables in the word. - """ - word = word.lower() - vowels = "aeiouy" - count = 0 - if word and word[0] in vowels: - count += 1 - for index in range(1, len(word)): - if word[index] in vowels and word[index - 1] not in vowels: - count += 1 - if word.endswith("e") and not word.endswith("le"): - count -= 1 - if count == 0: - count = 1 - return count - - -SubclassUnion.update(baseclass=ToolBase, type_name="SMOGIndexTool", obj=SMOGIndexTool) diff --git a/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py b/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py index f9d2a297e..173b3ac4b 100644 --- a/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/tools/concrete/__init__.py @@ -1,25 +1,12 @@ -import importlib - -# Define a lazy loader function with a warning message if the module or class is not found -def _lazy_import(module_name, class_name): - try: - # Import the module - module = importlib.import_module(module_name) - # Dynamically get the class from the module - return getattr(module, class_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_name}' is not available. " - f"Please install the necessary dependencies to enable this functionality.") - return None - except AttributeError: - print(f"Warning: The class '{class_name}' was not found in module '{module_name}'.") - return None +from swarmauri.utils.LazyLoader import LazyLoader # List of tool names (file names without the ".py" extension) and corresponding class names tool_files = [ ("swarmauri.tools.concrete.AdditionTool", "AdditionTool"), - ("swarmauri.tools.concrete.AutomatedReadabilityIndexTool", "AutomatedReadabilityIndexTool"), + ( + "swarmauri.tools.concrete.AutomatedReadabilityIndexTool", + "AutomatedReadabilityIndexTool", + ), ("swarmauri.tools.concrete.CalculatorTool", "CalculatorTool"), ("swarmauri.tools.concrete.CodeExtractorTool", "CodeExtractorTool"), ("swarmauri.tools.concrete.CodeInterpreterTool", "CodeInterpreterTool"), @@ -40,9 +27,9 @@ def _lazy_import(module_name, class_name): ("swarmauri.tools.concrete.WeatherTool", "WeatherTool"), ] -# Lazy loading of tools, storing them in variables +# Lazy loading of tools using LazyLoader for module_name, class_name in tool_files: - globals()[class_name] = _lazy_import(module_name, class_name) + globals()[class_name] = LazyLoader(module_name, class_name) -# Adding the lazy-loaded tools to __all__ +# Adding tools to __all__ (still safe because LazyLoader doesn't raise errors until accessed) __all__ = [class_name for _, class_name in tool_files] diff --git a/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py b/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py index 95900d024..1b6619352 100644 --- a/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/tracing/concrete/__init__.py @@ -1,5 +1,17 @@ -from swarmauri.tracing.concrete.CallableTracer import CallableTracer -from swarmauri.tracing.concrete.ChainTracer import ChainTracer -from swarmauri.tracing.concrete.SimpleTraceContext import SimpleTraceContext -from swarmauri.tracing.concrete.TracedVariable import TracedVariable -from swarmauri.tracing.concrete.VariableTracer import VariableTracer 
+from swarmauri.utils._lazy_import import _lazy_import + +# List of tracing names (file names without the ".py" extension) and corresponding class names +tracing_files = [ + ("swarmauri.tracing.concrete.CallableTracer", "CallableTracer"), + ("swarmauri.tracing.concrete.ChainTracer", "ChainTracer"), + ("swarmauri.tracing.concrete.SimpleTraceContext", "SimpleTraceContext"), + ("swarmauri.tracing.concrete.TracedVariable", "TracedVariable"), + ("swarmauri.tracing.concrete.VariableTracer", "VariableTracer"), +] + +# Lazy loading of tracing classes, storing them in variables +for module_name, class_name in tracing_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded tracing classes to __all__ +__all__ = [class_name for _, class_name in tracing_files] diff --git a/pkgs/swarmauri/swarmauri/transports/__init__.py b/pkgs/swarmauri/swarmauri/transports/__init__.py new file mode 100644 index 000000000..4de373c3c --- /dev/null +++ b/pkgs/swarmauri/swarmauri/transports/__init__.py @@ -0,0 +1 @@ +from swarmauri.transports.concrete import * diff --git a/pkgs/swarmauri/swarmauri/transports/base/TransportBase.py b/pkgs/swarmauri/swarmauri/transports/base/TransportBase.py new file mode 100644 index 000000000..370ed1c8a --- /dev/null +++ b/pkgs/swarmauri/swarmauri/transports/base/TransportBase.py @@ -0,0 +1,50 @@ +from typing import Any, List, Optional, Literal +from pydantic import ConfigDict, Field +from enum import Enum, auto +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes +from swarmauri_core.transports.ITransport import ITransport + + +class TransportProtocol(Enum): + """ + Enumeration of transport protocols supported by the transport layer + """ + + UNICAST = auto() + MULTICAST = auto() + BROADCAST = auto() + PUBSUB = auto() + + +class TransportBase(ITransport, ComponentBase): + allowed_protocols: List[TransportProtocol] = [] + resource: Optional[str] = Field(default=ResourceTypes.TRANSPORT.value, frozen=True) + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + type: Literal["TransportBase"] = "TransportBase" + + def send(self, sender: str, recipient: str, message: Any) -> None: + """ + Send a message to a specific recipient. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("send() not implemented in subclass yet.") + + def broadcast(self, sender: str, message: Any) -> None: + """ + Broadcast a message to all potential recipients. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("broadcast() not implemented in subclass yet.") + + def multicast(self, sender: str, recipients: List[str], message: Any) -> None: + """ + Send a message to multiple specific recipients. + + Raises: + NotImplementedError: Subclasses must implement this method.
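+ + Example (editorial sketch, not part of the original patch): a concrete + subclass could implement this in terms of send():: + + for recipient in recipients: + self.send(sender, recipient, message)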
+ """ + raise NotImplementedError("multicast() not implemented in subclass yet.") diff --git a/pkgs/swarmauri/swarmauri/transports/base/__init__.py b/pkgs/swarmauri/swarmauri/transports/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkgs/swarmauri/swarmauri/transports/concrete/PubSubTransport.py b/pkgs/swarmauri/swarmauri/transports/concrete/PubSubTransport.py new file mode 100644 index 000000000..48303ad6a --- /dev/null +++ b/pkgs/swarmauri/swarmauri/transports/concrete/PubSubTransport.py @@ -0,0 +1,118 @@ +from typing import Dict, Any, List, Optional, Set, Literal +from uuid import uuid4 +import asyncio +from swarmauri.transports.base.TransportBase import TransportBase, TransportProtocol + + +class PubSubTransport(TransportBase): + allowed_protocols: List[TransportProtocol] = [TransportProtocol.PUBSUB] + _topics: Dict[str, Set[str]] = {} # Topic to subscriber mappings + _subscribers: Dict[str, asyncio.Queue] = {} + type: Literal["PubSubTransport"] = "PubSubTransport" + + async def subscribe(self, topic: str) -> str: + """ + Subscribe an agent to a specific topic. + + Args: + topic (str): The topic to subscribe to + + Returns: + str: Unique subscriber ID + """ + subscriber_id = self.id + + # Create message queue for this subscribere + self._subscribers[subscriber_id] = asyncio.Queue() + + # Add subscriber to topic + if topic not in self._topics: + self._topics[topic] = set() + self._topics[topic].add(subscriber_id) + + return subscriber_id + + async def unsubscribe(self, topic: str): + """ + Unsubscribe an agent from a topic. + + Args: + topic (str): The topic to unsubscribe from + subscriber_id (str): Unique identifier of the subscriber + """ + subscriber_id = self.id + if topic in self._topics and subscriber_id in self._topics[topic]: + self._topics[topic].remove(subscriber_id) + + # Optional: Clean up if no subscribers remain + if not self._topics[topic]: + del self._topics[topic] + + async def publish(self, topic: str, message: Any): + """ + Publish a message to a specific topic. + + Args: + topic (str): The topic to publish to + message (Any): The message to be published + """ + if topic not in self._topics: + return + + # Distribute message to all subscribers of this topic + for subscriber_id in self._topics[topic]: + await self._subscribers[subscriber_id].put(message) + + async def receive(self) -> Any: + """ + Receive messages for a specific subscriber. + + Args: + subscriber_id (str): Unique identifier of the subscriber + + Returns: + Any: Received message + """ + return await self._subscribers[self.id].get() + + def send(self, sender: str, recipient: str, message: Any) -> None: + """ + Simulate sending a direct message (not applicable in Pub/Sub context). + + Args: + sender (str): The sender ID + recipient (str): The recipient ID + message (Any): The message to send + + Raises: + NotImplementedError: This method is not applicable for Pub/Sub. + """ + raise NotImplementedError("Direct send not supported in Pub/Sub model.") + + def broadcast(self, sender: str, message: Any) -> None: + """ + Broadcast a message to all subscribers of all topics. + + Args: + sender (str): The sender ID + message (Any): The message to broadcast + """ + for topic in self._topics: + asyncio.create_task(self.publish(topic, message)) + + def multicast(self, sender: str, recipients: List[str], message: Any) -> None: + """ + Send a message to specific topics (acting as recipients). 
+ + Args: + sender (str): The sender ID + recipients (List[str]): Topics to send the message to + message (Any): The message to send + """ + for topic in recipients: + asyncio.create_task(self.publish(topic, message)) diff --git a/pkgs/swarmauri/swarmauri/transports/concrete/__init__.py b/pkgs/swarmauri/swarmauri/transports/concrete/__init__.py new file mode 100644 index 000000000..b9f008bb5 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/transports/concrete/__init__.py @@ -0,0 +1,13 @@ +from swarmauri.utils._lazy_import import _lazy_import + +# List of transport names (file names without the ".py" extension) and corresponding class names +transport_files = [ + ("swarmauri.transports.concrete.PubSubTransport", "PubSubTransport"), +] + +# Lazy loading of transport classes, storing them in variables +for module_name, class_name in transport_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded transport classes to __all__ +__all__ = [class_name for _, class_name in transport_files] diff --git a/pkgs/swarmauri/swarmauri/utils/LazyLoader.py b/pkgs/swarmauri/swarmauri/utils/LazyLoader.py new file mode 100644 index 000000000..065065861 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/utils/LazyLoader.py @@ -0,0 +1,31 @@ +import importlib + +class LazyLoader: + def __init__(self, module_name, class_name): + self.module_name = module_name + self.class_name = class_name + self._loaded_class = None + + def _load_class(self): + if self._loaded_class is None: + try: + module = importlib.import_module(self.module_name) + self._loaded_class = getattr(module, self.class_name) + except ImportError: + print( + f"Warning: The module '{self.module_name}' is not available. " + f"Please install the necessary dependencies to enable this functionality." + ) + self._loaded_class = None + except AttributeError: + print( + f"Warning: The class '{self.class_name}' was not found in module '{self.module_name}'." + ) + self._loaded_class = None + return self._loaded_class + + def __getattr__(self, item): + loaded_class = self._load_class() + if loaded_class is None: + raise ImportError(f"Unable to load class {self.class_name} from {self.module_name}") + return getattr(loaded_class, item) diff --git a/pkgs/swarmauri/swarmauri/utils/_get_subclasses.py b/pkgs/swarmauri/swarmauri/utils/_get_subclasses.py new file mode 100644 index 000000000..adf100601 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/utils/_get_subclasses.py @@ -0,0 +1,72 @@ +import importlib +import re + + +def get_classes_from_module(module_name: str): + """ + Dynamically imports a module and retrieves a dictionary of class names and their corresponding class objects. + + :param module_name: The name of the module (e.g., "parsers", "agent"). + :return: A dictionary with class names as keys and class objects as values.
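+ + Illustrative usage (editor's sketch; assumes the "Parser" resource resolves + to swarmauri.parsers.concrete and exposes BeautifulSoupElementParser via __all__): + + classes = get_classes_from_module("Parser") + parser_cls = classes.get("BeautifulSoupElementParser")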
+ """ + # Convert module name to lowercase to ensure consistency + module_name_lower = module_name.lower() + + # Construct the full module path dynamically + full_module_path = f"swarmauri.{module_name_lower}s.concrete" + + try: + # Import the module dynamically + module = importlib.import_module(full_module_path) + + # Get the list of class names from __all__ + class_names = getattr(module, "__all__", []) + + # Create a dictionary with class names and their corresponding class objects + classes_dict = { + class_name: getattr(module, class_name) for class_name in class_names + } + + return classes_dict + except ImportError as e: + print(f"Error importing module {full_module_path}: {e}") + raise ModuleNotFoundError(f"Resource '{module_name}' is not registered.") + except AttributeError as e: + print(f"Error accessing class in {full_module_path}: {e}") + raise e + + +def get_class_from_module(module_name: str, class_name: str): + """ + Dynamically imports a module and retrieves the class name of the module. + + :param module_name: The name of the module (e.g., "parsers", "agent"). + :return: The class name of the module. + """ + # Convert module name to lowercase to ensure consistency + module_name_lower = module_name.lower() + + # Construct the full module path dynamically + full_module_path = f"swarmauri.{module_name_lower}s.concrete" + + try: + # Import the module dynamically + module = importlib.import_module(full_module_path) + + # Get the list of class names from __all__ + class_names = getattr(module, "__all__", []) + + if not class_names: + raise AttributeError(f"No classes found in module {full_module_path}") + + for cls_name in class_names: + if cls_name == class_name: + return getattr(module, class_name) + return None + + except ImportError as e: + print(f"Error importing module {full_module_path}: {e}") + raise ModuleNotFoundError(f"Resource '{module_name}' is not found.") + except AttributeError as e: + print(f"Error accessing class in {full_module_path}: {e}") + raise e diff --git a/pkgs/swarmauri/swarmauri/utils/_lazy_import.py b/pkgs/swarmauri/swarmauri/utils/_lazy_import.py new file mode 100644 index 000000000..a3d3bd34a --- /dev/null +++ b/pkgs/swarmauri/swarmauri/utils/_lazy_import.py @@ -0,0 +1,22 @@ +import importlib + + +# Define a lazy loader function with a warning message if the module or class is not found +def _lazy_import(module_name, class_name): + try: + # Import the module + module = importlib.import_module(module_name) + # Dynamically get the class from the module + return getattr(module, class_name) + except ImportError: + # If module is not available, print a warning message + print( + f"Warning: The module '{module_name}' is not available. " + f"Please install the necessary dependencies to enable this functionality." + ) + return None + except AttributeError: + print( + f"Warning: The class '{class_name}' was not found in module '{module_name}'." 
+ ) + return None diff --git a/pkgs/swarmauri/swarmauri/utils/method_signature_extractor_decorator.py b/pkgs/swarmauri/swarmauri/utils/method_signature_extractor_decorator.py new file mode 100644 index 000000000..20375df83 --- /dev/null +++ b/pkgs/swarmauri/swarmauri/utils/method_signature_extractor_decorator.py @@ -0,0 +1,113 @@ +from typing import ( + List, + Any, + Union, + Optional, + Callable, + get_type_hints, + get_args, + get_origin, +) +import inspect +from functools import wraps +from pydantic import BaseModel +from swarmauri.tools.concrete.Parameter import Parameter + + +class MethodSignatureExtractor(BaseModel): + parameters: List[Parameter] = [] + method: Callable + _type_mapping: dict = { + int: "integer", + float: "number", + str: "string", + bool: "boolean", + list: "array", + dict: "object", + Any: "any", + } + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.parameters = self.extract_signature_details() + + def _python_type_to_json_schema_type(self, py_type): + if get_origin(py_type) is not None: + origin = get_origin(py_type) + args = get_args(py_type) + + if origin is list: + items_type = self._python_type_to_json_schema_type(args[0]) + return {"type": "array", "items": items_type} + elif origin is dict: + return {"type": "object"} + elif origin in (Union, Optional): + if len(args) == 2 and type(None) in args: + non_none_type = args[0] if args[1] is type(None) else args[1] + return self._python_type_to_json_schema_type(non_none_type) + return { + "oneOf": [ + self._python_type_to_json_schema_type(arg) for arg in args + ] + } + return {"type": self._type_mapping.get(origin, "string")} + else: + return {"type": self._type_mapping.get(py_type, "string")} + + def extract_signature_details(self): + sig = inspect.signature(self.method) + type_hints = get_type_hints(self.method) + parameters = sig.parameters + details_list = [] + for param_name, param in parameters.items(): + if param_name == "self": + continue + + param_type = type_hints.get(param_name, Any) + param_default = ( + param.default if param.default is not inspect.Parameter.empty else None + ) + required = param.default is inspect.Parameter.empty + enum = None + param_type_json_schema = self._python_type_to_json_schema_type(param_type) + + if "oneOf" in param_type_json_schema: + param_type_json_schema["type"] = [ + type_["type"] for type_ in param_type_json_schema["oneOf"] + ] + + description = f"Parameter {param_name} of type {param_type_json_schema}" + + detail = Parameter( + name=param_name, + type=param_type_json_schema["type"], + description=description, + required=required, + enum=enum, + ) + details_list.append(detail) + + return details_list + + +def extract_method_signature(func: Callable): + """ + A decorator that extracts method signature details and attaches them to the function. + + Args: + func (Callable): The function to extract signature details for. + + Returns: + Callable: The original function with added signature_details attribute.
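+ + Illustrative usage (editor's sketch, not part of the original patch): + + @extract_method_signature + def greet(name: str, excited: bool = False) -> str: + return "Hello!" if excited else "Hello." + + # greet.signature_details is a list of Parameter objects describing + # 'name' (required string) and 'excited' (optional boolean).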
+ """ + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + extractor = MethodSignatureExtractor(method=func) + + wrapper.signature_details = extractor.parameters + + return wrapper diff --git a/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py b/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py index 08a36e26c..ceb2b245c 100644 --- a/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/vector_stores/concrete/__init__.py @@ -1,26 +1,14 @@ -import importlib +from swarmauri.utils._lazy_import import _lazy_import -# Define a lazy loader function with a warning message if the module is not found -def _lazy_import(module_name, module_description=None): - try: - return importlib.import_module(module_name) - except ImportError: - # If module is not available, print a warning message - print(f"Warning: The module '{module_description or module_name}' is not available. " - f"Please install the necessary dependencies to enable this functionality.") - return None - -# List of vector store names (file names without the ".py" extension) -vector_store_files = [ - "Doc2VecVectorStore", - "MlmVectorStore", - "SqliteVectorStore", - "TfidfVectorStore", +# List of vectore_stores names (file names without the ".py" extension) and corresponding class names +vectore_stores_files = [ + ("swarmauri.vector_stores.concrete.SqliteVectorStore", "SqliteVectorStore"), + ("swarmauri.vector_stores.concrete.TfidfVectorStore", "TfidfVectorStore"), ] -# Lazy loading of vector stores, storing them in variables -for vector_store in vector_store_files: - globals()[vector_store] = _lazy_import(f"swarmauri.vector_stores.concrete.{vector_store}", vector_store) +# Lazy loading of vectore_storess, storing them in variables +for module_name, class_name in vectore_stores_files: + globals()[class_name] = _lazy_import(module_name, class_name) -# Adding the lazy-loaded vector stores to __all__ -__all__ = vector_store_files +# Adding the lazy-loaded vectore_storess to __all__ +__all__ = [class_name for _, class_name in vectore_stores_files] diff --git a/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py b/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py index 16f348f20..7283bc0a9 100644 --- a/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py +++ b/pkgs/swarmauri/swarmauri/vectors/concrete/__init__.py @@ -1,4 +1,14 @@ -# -*- coding: utf-8 -*- +from swarmauri.utils._lazy_import import _lazy_import -from swarmauri.vectors.concrete.Vector import Vector -from swarmauri.vectors.concrete.VectorProductMixin import VectorProductMixin +# List of vectors names (file names without the ".py" extension) and corresponding class names +vectors_files = [ + ("swarmauri.vectors.concrete.Vector", "Vector"), + ("swarmauri.vectors.concrete.VectorProductMixin", "VectorProductMixin"), +] + +# Lazy loading of vectorss, storing them in variables +for module_name, class_name in vectors_files: + globals()[class_name] = _lazy_import(module_name, class_name) + +# Adding the lazy-loaded vectorss to __all__ +__all__ = [class_name for _, class_name in vectors_files] diff --git a/pkgs/swarmauri/tests/static/credentials.json b/pkgs/swarmauri/tests/static/credentials.json new file mode 100644 index 000000000..a36c5e36b --- /dev/null +++ b/pkgs/swarmauri/tests/static/credentials.json @@ -0,0 +1,5 @@ +{ + "client_id": "Put your client_id here", + "client_secret": "Put your client_secret here", + "redirect_uri": "put your redirect_uri here or use :ietf:wg:oauth:2.0:oob" +} \ No 
newline at end of file diff --git a/pkgs/swarmauri/tests/static/hyperbolic_test2.mp3 b/pkgs/swarmauri/tests/static/hyperbolic_test2.mp3 new file mode 100644 index 000000000..1fabb16fa Binary files /dev/null and b/pkgs/swarmauri/tests/static/hyperbolic_test2.mp3 differ diff --git a/pkgs/swarmauri/tests/static/hyperbolic_test3.mp3 b/pkgs/swarmauri/tests/static/hyperbolic_test3.mp3 new file mode 100644 index 000000000..f8a7a0a8b Binary files /dev/null and b/pkgs/swarmauri/tests/static/hyperbolic_test3.mp3 differ diff --git a/pkgs/swarmauri/tests/static/hyperbolic_test_tts.mp3 b/pkgs/swarmauri/tests/static/hyperbolic_test_tts.mp3 new file mode 100644 index 000000000..194714d16 Binary files /dev/null and b/pkgs/swarmauri/tests/static/hyperbolic_test_tts.mp3 differ diff --git a/pkgs/swarmauri/tests/unit/control_panels/ControlPanel_unit_test.py b/pkgs/swarmauri/tests/unit/control_panels/ControlPanel_unit_test.py new file mode 100644 index 000000000..78d266d46 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/control_panels/ControlPanel_unit_test.py @@ -0,0 +1,194 @@ +import pytest +from unittest.mock import MagicMock +from swarmauri.control_panels.concrete.ControlPanel import ControlPanel +from swarmauri.factories.base.FactoryBase import FactoryBase +from swarmauri.service_registries.base.ServiceRegistryBase import ServiceRegistryBase +from swarmauri.task_mgt_strategies.base.TaskMgtStrategyBase import TaskMgtStrategyBase +from swarmauri.transports.base.TransportBase import TransportBase + +from pydantic import BaseModel + +class SerializableMagicMock(MagicMock, BaseModel): + """A MagicMock class that can be serialized using Pydantic.""" + + def dict(self, *args, **kwargs): + """Serialize the mock object to a dictionary.""" + return {"mock_name": self._mock_name, "calls": self.mock_calls} + + def json(self, *args, **kwargs): + """Serialize the mock object to a JSON string.""" + return super().json(*args, **kwargs) + +@pytest.fixture +def control_panel(): + """Fixture to create a fully mocked ControlPanel instance with serializable mocks.""" + + # Create serializable mocks for all dependencies + agent_factory = SerializableMagicMock(spec=FactoryBase) + agent_factory.create_agent = SerializableMagicMock(return_value="MockAgent") + agent_factory.get_agent_by_name = SerializableMagicMock(return_value="MockAgent") + agent_factory.delete_agent = SerializableMagicMock() + agent_factory.get_agents = SerializableMagicMock(return_value=["MockAgent1", "MockAgent2"]) + + service_registry = SerializableMagicMock(spec=ServiceRegistryBase) + service_registry.register_service = SerializableMagicMock() + service_registry.unregister_service = SerializableMagicMock() + service_registry.get_services = SerializableMagicMock(return_value=["service1", "service2"]) + + task_mgt_strategy = SerializableMagicMock(spec=TaskMgtStrategyBase) + task_mgt_strategy.add_task = SerializableMagicMock() + task_mgt_strategy.process_tasks = SerializableMagicMock() + task_mgt_strategy.assign_task = SerializableMagicMock() + + transport = SerializableMagicMock(spec=TransportBase) + + # Return the ControlPanel instance with mocked dependencies + return ControlPanel( + agent_factory=agent_factory, + service_registry=service_registry, + task_mgt_strategy=task_mgt_strategy, + transport=transport, + ) + +def test_create_agent(control_panel): + """Test the create_agent method.""" + agent_name = "agent1" + agent_role = "worker" + agent = MagicMock() + + # Configure mocks + 
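# Override the fixture's canned "MockAgent" return value with this test's local mock agent. + 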
control_panel.agent_factory.create_agent.return_value = agent + + # Call the method + result = control_panel.create_agent(agent_name, agent_role) + + # Assertions + control_panel.agent_factory.create_agent.assert_called_once_with( + agent_name, agent_role + ) + control_panel.service_registry.register_service.assert_called_once_with( + agent_name, {"role": agent_role, "status": "active"} + ) + assert result == agent + + +def test_remove_agent(control_panel): + """Test the remove_agent method.""" + agent_name = "agent1" + agent = MagicMock() + + # Configure mocks + control_panel.agent_factory.get_agent_by_name.return_value = agent + + # Call the method + control_panel.remove_agent(agent_name) + + # Assertions + control_panel.agent_factory.get_agent_by_name.assert_called_once_with(agent_name) + control_panel.service_registry.unregister_service.assert_called_once_with( + agent_name + ) + control_panel.agent_factory.delete_agent.assert_called_once_with(agent_name) + + +def test_remove_agent_not_found(control_panel): + """Test remove_agent when the agent is not found.""" + agent_name = "agent1" + + # Configure mocks + control_panel.agent_factory.get_agent_by_name.return_value = None + + # Call the method and expect a ValueError + with pytest.raises(ValueError) as exc_info: + control_panel.remove_agent(agent_name) + assert str(exc_info.value) == f"Agent '{agent_name}' not found." + + +def test_list_active_agents(control_panel): + """Test the list_active_agents method.""" + agent1 = MagicMock() + agent1.name = "agent1" + agent2 = MagicMock() + agent2.name = "agent2" + agents = [agent1, agent2] + + # Configure mocks + control_panel.agent_factory.get_agents.return_value = agents + + # Call the method + result = control_panel.list_active_agents() + + # Assertions + control_panel.agent_factory.get_agents.assert_called_once() + assert result == ["agent1", "agent2"] + + +def test_submit_tasks(control_panel): + """Test the submit_tasks method.""" + task1 = {"task_id": "task1"} + task2 = {"task_id": "task2"} + tasks = [task1, task2] + + # Call the method + control_panel.submit_tasks(tasks) + + # Assertions + calls = [((task1,),), ((task2,),)] + control_panel.task_mgt_strategy.add_task.assert_has_calls(calls) + assert control_panel.task_mgt_strategy.add_task.call_count == 2 + + +def test_process_tasks(control_panel): + """Test the process_tasks method.""" + # Call the method + control_panel.process_tasks() + + # Assertions + control_panel.task_mgt_strategy.process_tasks.assert_called_once_with( + control_panel.service_registry.get_services, control_panel.transport + ) + + +def test_process_tasks_exception(control_panel, caplog): + """Test process_tasks when an exception occurs.""" + # Configure mocks + control_panel.task_mgt_strategy.process_tasks.side_effect = Exception("Test error") + + # Call the method + control_panel.process_tasks() + + # Assertions + control_panel.task_mgt_strategy.process_tasks.assert_called_once_with( + control_panel.service_registry.get_services, control_panel.transport + ) + assert "Error while processing tasks: Test error" in caplog.text + + +def test_distribute_tasks(control_panel): + """Test the distribute_tasks method.""" + task = {"task_id": "task1"} + + # Call the method + control_panel.distribute_tasks(task) + + # Assertions + control_panel.task_mgt_strategy.assign_task.assert_called_once_with( + task, control_panel.service_registry.get_services + ) + + +def test_orchestrate_agents(control_panel): + """Test the orchestrate_agents method.""" + tasks = [{"task_id": 
"task1"}, {"task_id": "task2"}] + + # Configure mocks + control_panel.submit_tasks = MagicMock() + control_panel.process_tasks = MagicMock() + + # Call the method + control_panel.orchestrate_agents(tasks) + + # Assertions + control_panel.submit_tasks.assert_called_once_with(tasks) + control_panel.process_tasks.assert_called_once() diff --git a/pkgs/swarmauri/tests/unit/dataconnectors/GoogleDriveDataConnector_unit_test.py b/pkgs/swarmauri/tests/unit/dataconnectors/GoogleDriveDataConnector_unit_test.py new file mode 100644 index 000000000..9b1a092da --- /dev/null +++ b/pkgs/swarmauri/tests/unit/dataconnectors/GoogleDriveDataConnector_unit_test.py @@ -0,0 +1,86 @@ +import pytest +from swarmauri.dataconnectors.concrete.GoogleDriveDataConnector import ( + GoogleDriveDataConnector, +) + + +@pytest.fixture(scope="module") +def authenticated_connector(): + """Authenticate the GoogleDriveDataConnector once for the test suite.""" + # Path to the valid credentials JSON file + credentials_path = "pkgs/swarmauri/tests/static/credentials.json" + connector = GoogleDriveDataConnector(credentials_path=credentials_path) + + # Perform authentication once + try: + connector.authenticate() # Requires manual input for the authorization code + except Exception as e: + pytest.fail(f"Authentication failed: {e}") + + return connector + + +@pytest.fixture(scope="module") +def shared_file_id(): + """Return a shared file ID for testing.""" + return {} + + +@pytest.mark.skip(reason="Skipping test_generate_authorization_url") +def test_generate_authorization_url(): + """Test generate_authorization_url without authentication.""" + # Path to the valid credentials JSON file + credentials_path = "pkgs/swarmauri/tests/static/credentials.json" + connector = GoogleDriveDataConnector(credentials_path=credentials_path) + url = connector.generate_authorization_url() + assert isinstance(url, str) + assert "client_id" in url + assert "redirect_uri" in url + assert "https://accounts.google.com/o/oauth2/v2/auth" in url + + +@pytest.mark.skip(reason="Skipping test_fetch_data") +def test_fetch_data(authenticated_connector): + """Test fetching data from Google Drive.""" + documents = authenticated_connector.fetch_data(query="test") + assert isinstance(documents, list) + if documents: + assert all(hasattr(doc, "content") for doc in documents) + assert all(hasattr(doc, "metadata") for doc in documents) + + +@pytest.mark.skip(reason="Skipping test_insert_data") +def test_insert_data(authenticated_connector, shared_file_id): + """Test inserting data into Google Drive.""" + test_data = "Sample content for Google Drive file" + file_id = authenticated_connector.insert_data(test_data, filename="test_file.txt") + assert isinstance(file_id, str) + shared_file_id["file_id"] = file_id + + +@pytest.mark.skip(reason="Skipping test_update_data") +def test_update_data(authenticated_connector, shared_file_id): + """Test updating data in Google Drive.""" + file_id = shared_file_id["file_id"] + updated_content = "Updated content for Google Drive file" + try: + authenticated_connector.update_data(file_id, updated_content) + except Exception as e: + pytest.fail(f"Failed to update file: {e}") + + +@pytest.mark.skip(reason="Skipping test_delete_data") +def test_delete_data(authenticated_connector, shared_file_id): + """Test deleting data from Google Drive.""" + file_id = shared_file_id["file_id"] # Replace with an actual file ID + try: + authenticated_connector.delete_data(file_id) + except Exception as e: + pytest.fail(f"Failed to delete file: {e}") + + 
+@pytest.mark.skip(reason="Skipping test_connection") +def test_connection(authenticated_connector): + """Test the connection to Google Drive.""" + connection_success = authenticated_connector.test_connection() + assert connection_success is True diff --git a/pkgs/swarmauri/tests/unit/factories/AgentFactory_unit_test.py b/pkgs/swarmauri/tests/unit/factories/AgentFactory_unit_test.py new file mode 100644 index 000000000..5a0254b70 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/factories/AgentFactory_unit_test.py @@ -0,0 +1,66 @@ +import pytest +from swarmauri.factories.concrete.AgentFactory import AgentFactory +import os +from swarmauri.llms.concrete.GroqModel import GroqModel +from swarmauri.agents.concrete import QAAgent +from dotenv import load_dotenv + +load_dotenv() + + +@pytest.fixture(scope="module") +def groq_model(): + API_KEY = os.getenv("GROQ_API_KEY") + if not API_KEY: + pytest.skip("Skipping due to environment variable not set") + llm = GroqModel(api_key=API_KEY) + return llm + + +@pytest.fixture(scope="module") +def agent_factory(): + return AgentFactory() + + +@pytest.mark.unit +def test_ubc_resource(agent_factory): + assert agent_factory.resource == "Factory" + + +@pytest.mark.unit +def test_ubc_type(agent_factory): + assert agent_factory.type == "AgentFactory" + + +@pytest.mark.unit +def test_serialization(agent_factory): + assert ( + agent_factory.id + == AgentFactory.model_validate_json(agent_factory.model_dump_json()).id + ) + + +@pytest.mark.unit +def test_agent_factory_register_and_create(agent_factory, groq_model): + + agent_factory.register(type="QAAgent", resource_class=QAAgent) + + # Create an instance + instance = agent_factory.create(type="QAAgent", llm=groq_model) + assert isinstance(instance, QAAgent) + assert instance.type == "QAAgent" + + +@pytest.mark.unit +def test_agent_factory_create_unregistered_type(agent_factory): + + # Attempt to create an unregistered type + with pytest.raises(ValueError, match="Type 'UnregisteredType' is not registered."): + agent_factory.create(type="UnregisteredType") + + +@pytest.mark.unit +def test_agent_factory_get_agents(agent_factory): + + assert agent_factory.get_agents() == ["QAAgent"] + assert len(agent_factory.get_agents()) == 1 diff --git a/pkgs/swarmauri/tests/unit/factories/Factory_unit_test.py b/pkgs/swarmauri/tests/unit/factories/Factory_unit_test.py new file mode 100644 index 000000000..11695f7ec --- /dev/null +++ b/pkgs/swarmauri/tests/unit/factories/Factory_unit_test.py @@ -0,0 +1,75 @@ +import pytest +from swarmauri.factories.concrete.Factory import Factory +from swarmauri.parsers.concrete.BeautifulSoupElementParser import ( + BeautifulSoupElementParser, +) + + +@pytest.fixture(scope="module") +def factory(): + return Factory() + + +@pytest.mark.unit +def test_ubc_resource(factory): + assert factory.resource == "Factory" + + +@pytest.mark.unit +def test_ubc_type(factory): + assert factory.type == "Factory" + + +@pytest.mark.unit +def test_serialization(factory): + assert factory.id == Factory.model_validate_json(factory.model_dump_json()).id + + +@pytest.mark.unit +def test_factory_register_create_resource(factory): + + # Register a resource and type + factory.register("Parser", "BeautifulSoupElementParser", BeautifulSoupElementParser) + + html_content = "
<div><p>Sample HTML content</p></div>
" + + # Create an instance + instance = factory.create( + "Parser", "BeautifulSoupElementParser", element=html_content + ) + assert isinstance(instance, BeautifulSoupElementParser) + assert instance.type == "BeautifulSoupElementParser" + + +@pytest.mark.unit +def test_factory_create_unregistered_resource(factory): + + # Attempt to create an instance of an unregistered resource + with pytest.raises( + ModuleNotFoundError, match="Resource 'UnknownResource' is not registered." + ): + factory.create("UnknownResource", "BeautifulSoupElementParser") + + +@pytest.mark.unit +def test_factory_duplicate_register(factory): + + # Attempt to register the same type again + with pytest.raises( + ValueError, + match="Type 'BeautifulSoupElementParser' is already registered under resource 'Parser'.", + ): + factory.register( + "Parser", "BeautifulSoupElementParser", BeautifulSoupElementParser + ) + + +@pytest.mark.unit +def test_factory_create_unregistered_type(factory): + + # Attempt to create an instance of an unregistered type + with pytest.raises( + ValueError, + match="Type 'UnknownType' is not registered under resource 'Parser'.", + ): + factory.create("Parser", "UnknownType") diff --git a/pkgs/swarmauri/tests/unit/llms/BlackForestImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/BlackForestImgGenModel_unit_test.py similarity index 95% rename from pkgs/swarmauri/tests/unit/llms/BlackForestImgGenModel_unit_test.py rename to pkgs/swarmauri/tests/unit/image_gens/BlackForestImgGenModel_unit_test.py index 706d03b61..5fbd06c8a 100644 --- a/pkgs/swarmauri/tests/unit/llms/BlackForestImgGenModel_unit_test.py +++ b/pkgs/swarmauri/tests/unit/image_gens/BlackForestImgGenModel_unit_test.py @@ -1,7 +1,7 @@ import pytest import os from dotenv import load_dotenv -from swarmauri.llms.concrete.BlackForestImgGenModel import ( +from swarmauri.image_gens.concrete.BlackForestImgGenModel import ( BlackForestImgGenModel, ) @@ -30,7 +30,7 @@ def get_allowed_models(): @timeout(5) @pytest.mark.unit def test_model_resource(blackforest_imggen_model): - assert blackforest_imggen_model.resource == "LLM" + assert blackforest_imggen_model.resource == "ImageGen" @timeout(5) diff --git a/pkgs/swarmauri/tests/unit/llms/DeepInfraImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/DeepInfraImgGenModel_unit_test.py similarity index 97% rename from pkgs/swarmauri/tests/unit/llms/DeepInfraImgGenModel_unit_test.py rename to pkgs/swarmauri/tests/unit/image_gens/DeepInfraImgGenModel_unit_test.py index 98b3b7047..5492ff573 100644 --- a/pkgs/swarmauri/tests/unit/llms/DeepInfraImgGenModel_unit_test.py +++ b/pkgs/swarmauri/tests/unit/image_gens/DeepInfraImgGenModel_unit_test.py @@ -1,6 +1,6 @@ import pytest import os -from swarmauri.llms.concrete.DeepInfraImgGenModel import DeepInfraImgGenModel +from swarmauri.image_gens.concrete.DeepInfraImgGenModel import DeepInfraImgGenModel from dotenv import load_dotenv from swarmauri.utils.timeout_wrapper import timeout diff --git a/pkgs/swarmauri/tests/unit/llms/FalAIImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/FalAIImgGenModel_unit_test.py similarity index 97% rename from pkgs/swarmauri/tests/unit/llms/FalAIImgGenModel_unit_test.py rename to pkgs/swarmauri/tests/unit/image_gens/FalAIImgGenModel_unit_test.py index bf5b6d83f..858414f7f 100644 --- a/pkgs/swarmauri/tests/unit/llms/FalAIImgGenModel_unit_test.py +++ b/pkgs/swarmauri/tests/unit/image_gens/FalAIImgGenModel_unit_test.py @@ -1,6 +1,6 @@ import pytest import os -from swarmauri.llms.concrete.FalAIImgGenModel 
import FalAIImgGenModel +from swarmauri.image_gens.concrete.FalAIImgGenModel import FalAIImgGenModel from dotenv import load_dotenv from swarmauri.utils.timeout_wrapper import timeout diff --git a/pkgs/swarmauri/tests/unit/image_gens/HyperbolicImgGenModel_unit_test.py b/pkgs/swarmauri/tests/unit/image_gens/HyperbolicImgGenModel_unit_test.py new file mode 100644 index 000000000..3772b4dce --- /dev/null +++ b/pkgs/swarmauri/tests/unit/image_gens/HyperbolicImgGenModel_unit_test.py @@ -0,0 +1,118 @@ +import pytest +import os +from swarmauri.image_gens.concrete.HyperbolicImgGenModel import HyperbolicImgGenModel +from dotenv import load_dotenv + +from swarmauri.utils.timeout_wrapper import timeout + +load_dotenv() + +API_KEY = os.getenv("HYPERBOLIC_API_KEY") + + +@pytest.fixture(scope="module") +def hyperbolic_imggen_model(): + if not API_KEY: + pytest.skip("Skipping due to environment variable not set") + model = HyperbolicImgGenModel(api_key=API_KEY) + return model + + +def get_allowed_models(): + if not API_KEY: + return [] + model = HyperbolicImgGenModel(api_key=API_KEY) + return model.allowed_models + + +@timeout(5) +@pytest.mark.unit +def test_ubc_resource(hyperbolic_imggen_model): + assert hyperbolic_imggen_model.resource == "ImageGen" + + +@timeout(5) +@pytest.mark.unit +def test_ubc_type(hyperbolic_imggen_model): + assert hyperbolic_imggen_model.type == "HyperbolicImgGenModel" + + +@timeout(5) +@pytest.mark.unit +def test_serialization(hyperbolic_imggen_model): + assert ( + hyperbolic_imggen_model.id + == HyperbolicImgGenModel.model_validate_json( + hyperbolic_imggen_model.model_dump_json() + ).id + ) + + +@timeout(5) +@pytest.mark.unit +def test_default_name(hyperbolic_imggen_model): + assert hyperbolic_imggen_model.name == "SDXL1.0-base" + + +@timeout(5) +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +def test_generate_image_base64(hyperbolic_imggen_model, model_name): + model = hyperbolic_imggen_model + model.name = model_name + + prompt = "A cute cat playing with a ball of yarn" + + image_base64 = model.generate_image_base64(prompt=prompt) + + assert isinstance(image_base64, str) + assert len(image_base64) > 0 + + +@timeout(5) +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +async def test_agenerate_image_base64(hyperbolic_imggen_model, model_name): + model = hyperbolic_imggen_model + model.name = model_name + + prompt = "A serene landscape with mountains and a lake" + + image_base64 = await model.agenerate_image_base64(prompt=prompt) + + assert isinstance(image_base64, str) + assert len(image_base64) > 0 + + +@timeout(5) +@pytest.mark.unit +def test_batch_base64(hyperbolic_imggen_model): + prompts = [ + "A futuristic city skyline", + "A tropical beach at sunset", + ] + + result_base64_images = hyperbolic_imggen_model.batch_base64(prompts=prompts) + + assert len(result_base64_images) == len(prompts) + for image_base64 in result_base64_images: + assert isinstance(image_base64, str) + assert len(image_base64) > 0 + + +@timeout(5) +@pytest.mark.asyncio +@pytest.mark.unit +async def test_abatch_base64(hyperbolic_imggen_model): + prompts = [ + "An abstract painting with vibrant colors", + "A snowy mountain peak", + ] + + result_base64_images = await hyperbolic_imggen_model.abatch_base64(prompts=prompts) + + assert len(result_base64_images) == len(prompts) + for image_base64 in result_base64_images: + assert isinstance(image_base64, str) + assert len(image_base64) > 0 diff --git 
a/pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py b/pkgs/swarmauri/tests/unit/image_gens/OpenAIImgGenModel_unit_tesst.py similarity index 97% rename from pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py rename to pkgs/swarmauri/tests/unit/image_gens/OpenAIImgGenModel_unit_tesst.py index 7780ba042..b22b9e6ea 100644 --- a/pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py +++ b/pkgs/swarmauri/tests/unit/image_gens/OpenAIImgGenModel_unit_tesst.py @@ -1,7 +1,7 @@ import pytest import os from dotenv import load_dotenv -from swarmauri.llms.concrete.OpenAIImgGenModel import OpenAIImgGenModel +from swarmauri.image_gens.concrete.OpenAIImgGenModel import OpenAIImgGenModel from swarmauri.utils.timeout_wrapper import timeout load_dotenv() diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicAudioTTS_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicAudioTTS_unit_test.py new file mode 100644 index 000000000..9b81c84ef --- /dev/null +++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicAudioTTS_unit_test.py @@ -0,0 +1,141 @@ +import logging +import pytest +import os + +from swarmauri.llms.concrete.HyperbolicAudioTTS import HyperbolicAudioTTS as LLM +from dotenv import load_dotenv +from swarmauri.utils.timeout_wrapper import timeout +from pathlib import Path + +load_dotenv() + +API_KEY = os.getenv("HYPERBOLIC_API_KEY") + + +# Resolve the tests root directory (two levels up from this file) +root_dir = Path(__file__).resolve().parents[2] + +# Construct file paths dynamically +file_path = os.path.join(root_dir, "static", "hyperbolic_test_tts.mp3") +file_path2 = os.path.join(root_dir, "static", "hyperbolic_test2.mp3") +file_path3 = os.path.join(root_dir, "static", "hyperbolic_test3.mp3") + + +@pytest.fixture(scope="module") +def hyperbolic_model(): + if not API_KEY: + pytest.skip("Skipping due to environment variable not set") + llm = LLM(api_key=API_KEY) + return llm + + +@timeout(5) +def get_allowed_languages(): + if not API_KEY: + return [] + llm = LLM(api_key=API_KEY) + return llm.allowed_languages + + +@timeout(5) +@pytest.mark.unit +def test_ubc_resource(hyperbolic_model): + assert hyperbolic_model.resource == "LLM" + + +@timeout(5) +@pytest.mark.unit +def test_ubc_type(hyperbolic_model): + assert hyperbolic_model.type == "HyperbolicAudioTTS" + + +@timeout(5) +@pytest.mark.unit +def test_serialization(hyperbolic_model): + assert ( + hyperbolic_model.id + == LLM.model_validate_json(hyperbolic_model.model_dump_json()).id + ) + + +@timeout(5) +@pytest.mark.unit +def test_default_speed(hyperbolic_model): + assert hyperbolic_model.speed == 1.0 + + +@timeout(5) +@pytest.mark.parametrize("language", get_allowed_languages()) +@pytest.mark.unit +def test_predict(hyperbolic_model, language): + """ + Test prediction with different languages + Note: Adjust the text according to the language if needed + """ + # Set the language for the test + hyperbolic_model.language = language + + # Select an appropriate text based on the language + texts = { + "EN": "Hello, this is a test of text-to-speech output in English.", + "ES": "Hola, esta es una prueba de salida de texto a voz en español.", + "FR": "Bonjour, ceci est un test de sortie de texte en français.", + "ZH": "这是一个中文语音转换测试。", + "JP": "これは日本語の音声合成テストです。", + "KR": "이것은 한국어 음성 합성 테스트입니다.", + } + + text = texts.get( + language, "Hello, this is a generic test of text-to-speech output."
+ ) + + audio_file_path = hyperbolic_model.predict(text=text, audio_path=file_path) + + logging.info(audio_file_path) + + assert isinstance(audio_file_path, str) + assert os.path.exists(audio_file_path) + assert os.path.getsize(audio_file_path) > 0 + + +@timeout(5) +@pytest.mark.unit +def test_batch(hyperbolic_model): + """ + Test batch processing of multiple texts + """ + text_path_dict = { + "Hello": file_path, + "Hi there": file_path2, + "Good morning": file_path3, + } + + results = hyperbolic_model.batch(text_path_dict=text_path_dict) + assert len(results) == len(text_path_dict) + + for result in results: + assert isinstance(result, str) + assert os.path.exists(result) + assert os.path.getsize(result) > 0 + + +@timeout(5) +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.unit +async def test_abatch(hyperbolic_model): + """ + Test asynchronous batch processing of multiple texts + """ + text_path_dict = { + "Hello": file_path, + "Hi there": file_path2, + "Good morning": file_path3, + } + + results = await hyperbolic_model.abatch(text_path_dict=text_path_dict) + assert len(results) == len(text_path_dict) + + for result in results: + assert isinstance(result, str) + assert os.path.exists(result) + assert os.path.getsize(result) > 0 diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicModel_unit_test.py new file mode 100644 index 000000000..6a4bd9652 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicModel_unit_test.py @@ -0,0 +1,218 @@ +import logging +import pytest +import os + +from swarmauri.llms.concrete.HyperbolicModel import HyperbolicModel as LLM +from swarmauri.conversations.concrete.Conversation import Conversation + +from swarmauri.messages.concrete.HumanMessage import HumanMessage +from swarmauri.messages.concrete.SystemMessage import SystemMessage + +from swarmauri.messages.concrete.AgentMessage import UsageData + +from swarmauri.utils.timeout_wrapper import timeout + +from dotenv import load_dotenv + +load_dotenv() + +API_KEY = os.getenv("HYPERBOLIC_API_KEY") + + +@pytest.fixture(scope="module") +def hyperbolic_model(): + if not API_KEY: + pytest.skip("Skipping due to environment variable not set") + llm = LLM(api_key=API_KEY) + return llm + + +def get_allowed_models(): + if not API_KEY: + return [] + llm = LLM(api_key=API_KEY) + return llm.allowed_models + + +@timeout(5) +@pytest.mark.unit +def test_ubc_resource(hyperbolic_model): + assert hyperbolic_model.resource == "LLM" + + +@timeout(5) +@pytest.mark.unit +def test_ubc_type(hyperbolic_model): + assert hyperbolic_model.type == "HyperbolicModel" + + +@timeout(5) +@pytest.mark.unit +def test_serialization(hyperbolic_model): + assert ( + hyperbolic_model.id + == LLM.model_validate_json(hyperbolic_model.model_dump_json()).id + ) + + +@timeout(5) +@pytest.mark.unit +def test_default_name(hyperbolic_model): + assert hyperbolic_model.name == "meta-llama/Meta-Llama-3.1-8B-Instruct" + + +@timeout(5) +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +def test_no_system_context(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + conversation = Conversation() + + input_data = "Hello" + human_message = HumanMessage(content=input_data) + conversation.add_message(human_message) + + model.predict(conversation=conversation) + prediction = conversation.get_last().content + usage_data = conversation.get_last().usage + + logging.info(usage_data) + + assert type(prediction) is str + assert 
isinstance(usage_data, UsageData) + + +@timeout(5) +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +def test_preamble_system_context(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + conversation = Conversation() + + system_context = 'You only respond with the following phrase, "Jeff"' + system_message = SystemMessage(content=system_context) + conversation.add_message(system_message) + + input_data = "Hi" + human_message = HumanMessage(content=input_data) + conversation.add_message(human_message) + + model.predict(conversation=conversation) + prediction = conversation.get_last().content + usage_data = conversation.get_last().usage + + logging.info(usage_data) + + assert type(prediction) is str + assert "Jeff" in prediction + assert isinstance(usage_data, UsageData) + + +@timeout(5) +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +def test_stream(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + conversation = Conversation() + + input_data = "Write a short story about a cat." + human_message = HumanMessage(content=input_data) + conversation.add_message(human_message) + + collected_tokens = [] + for token in model.stream(conversation=conversation): + logging.info(token) + assert isinstance(token, str) + collected_tokens.append(token) + + full_response = "".join(collected_tokens) + assert len(full_response) > 0 + assert conversation.get_last().content == full_response + assert isinstance(conversation.get_last().usage, UsageData) + + +@timeout(5) +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +async def test_apredict(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + conversation = Conversation() + + input_data = "Hello" + human_message = HumanMessage(content=input_data) + conversation.add_message(human_message) + + result = await model.apredict(conversation=conversation) + prediction = result.get_last().content + assert isinstance(prediction, str) + assert isinstance(conversation.get_last().usage, UsageData) + + +@timeout(5) +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +async def test_astream(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + conversation = Conversation() + + input_data = "Write a short story about a dog."
+ human_message = HumanMessage(content=input_data) + conversation.add_message(human_message) + + collected_tokens = [] + async for token in model.astream(conversation=conversation): + assert isinstance(token, str) + collected_tokens.append(token) + + full_response = "".join(collected_tokens) + assert len(full_response) > 0 + assert conversation.get_last().content == full_response + assert isinstance(conversation.get_last().usage, UsageData) + + +@timeout(5) +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +def test_batch(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + + conversations = [] + for prompt in ["Hello", "Hi there", "Good morning"]: + conv = Conversation() + conv.add_message(HumanMessage(content=prompt)) + conversations.append(conv) + + results = model.batch(conversations=conversations) + assert len(results) == len(conversations) + for result in results: + assert isinstance(result.get_last().content, str) + assert isinstance(result.get_last().usage, UsageData) + + +@timeout(5) +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("model_name", get_allowed_models()) +@pytest.mark.unit +async def test_abatch(hyperbolic_model, model_name): + model = hyperbolic_model + model.name = model_name + + conversations = [] + for prompt in ["Hello", "Hi there", "Good morning"]: + conv = Conversation() + conv.add_message(HumanMessage(content=prompt)) + conversations.append(conv) + + results = await model.abatch(conversations=conversations) + assert len(results) == len(conversations) + for result in results: + assert isinstance(result.get_last().content, str) + assert isinstance(result.get_last().usage, UsageData) diff --git a/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py b/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py new file mode 100644 index 000000000..495341aae --- /dev/null +++ b/pkgs/swarmauri/tests/unit/llms/HyperbolicVisionModel_unit_test.py @@ -0,0 +1,158 @@ +import pytest +import os +from swarmauri.llms.concrete.HyperbolicVisionModel import HyperbolicVisionModel +from swarmauri.conversations.concrete.Conversation import Conversation +from swarmauri.messages.concrete.HumanMessage import HumanMessage +from dotenv import load_dotenv +from swarmauri.utils.timeout_wrapper import timeout + +load_dotenv() + +API_KEY = os.getenv("HYPERBOLIC_API_KEY") + + +@pytest.fixture(scope="module") +def hyperbolic_vision_model(): + if not API_KEY: + pytest.skip("Skipping due to environment variable not set") + model = HyperbolicVisionModel(api_key=API_KEY) + return model + + +def get_allowed_models(): + if not API_KEY: + return [] + model = HyperbolicVisionModel(api_key=API_KEY) + return model.allowed_models + + +@timeout(5) +@pytest.mark.unit +def test_ubc_resource(hyperbolic_vision_model): + assert hyperbolic_vision_model.resource == "LLM" + + +@timeout(5) +@pytest.mark.unit +def test_ubc_type(hyperbolic_vision_model): + assert hyperbolic_vision_model.type == "HyperbolicVisionModel" + + +@timeout(5) +@pytest.mark.unit +def test_serialization(hyperbolic_vision_model): + assert ( + hyperbolic_vision_model.id + == HyperbolicVisionModel.model_validate_json( + hyperbolic_vision_model.model_dump_json() + ).id + ) + + +@timeout(5) +@pytest.mark.unit +def test_default_model_name(hyperbolic_vision_model): + assert hyperbolic_vision_model.name == "Qwen/Qwen2-VL-72B-Instruct" + + +def create_test_conversation(image_url, prompt): + conversation = Conversation() + conversation.add_message( 
+ HumanMessage( + content=[ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_url}}, + ] + ) + ) + return conversation + + +@pytest.mark.parametrize("model_name", get_allowed_models()) +@timeout(5) +@pytest.mark.unit +def test_predict(hyperbolic_vision_model, model_name): + model = hyperbolic_vision_model + model.name = model_name + + image_url = "https://llava-vl.github.io/static/images/monalisa.jpg" + prompt = "Who painted this artwork?" + conversation = create_test_conversation(image_url, prompt) + + result = model.predict(conversation) + + assert result.history[-1].content is not None + assert isinstance(result.history[-1].content, str) + assert len(result.history[-1].content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", get_allowed_models()) +@timeout(5) +@pytest.mark.unit +async def test_apredict(hyperbolic_vision_model, model_name): + model = hyperbolic_vision_model + model.name = model_name + + image_url = "https://llava-vl.github.io/static/images/monalisa.jpg" + prompt = "Describe the woman in the painting." + conversation = create_test_conversation(image_url, prompt) + + result = await model.apredict(conversation) + + assert result.history[-1].content is not None + assert isinstance(result.history[-1].content, str) + assert len(result.history[-1].content) > 0 + + +@timeout(5) +@pytest.mark.unit +def test_batch(hyperbolic_vision_model): + image_urls = [ + "https://llava-vl.github.io/static/images/monalisa.jpg", + "https://llava-vl.github.io/static/images/monalisa.jpg", + ] + prompts = [ + "Who painted this artwork?", + "Describe the woman in the painting.", + ] + + conversations = [ + create_test_conversation(image_url, prompt) + for image_url, prompt in zip(image_urls, prompts) + ] + + results = hyperbolic_vision_model.batch(conversations) + + assert len(results) == len(image_urls) + for result in results: + assert result.history[-1].content is not None + assert isinstance(result.history[-1].content, str) + assert len(result.history[-1].content) > 0 + + +@pytest.mark.asyncio +@timeout(5) +@pytest.mark.unit +async def test_abatch(hyperbolic_vision_model): + image_urls = [ + "https://llava-vl.github.io/static/images/monalisa.jpg", + "https://llava-vl.github.io/static/images/monalisa.jpg", + ] + prompts = [ + "Who painted this artwork?", + "Describe the woman in the painting.", + ] + + conversations = [ + create_test_conversation(image_url, prompt) + for image_url, prompt in zip(image_urls, prompts) + ] + + results = await hyperbolic_vision_model.abatch(conversations) + + assert len(results) == len(image_urls) + for result in results: + assert result.history[-1].content is not None + assert isinstance(result.history[-1].content, str) + assert len(result.history[-1].content) > 0 diff --git a/pkgs/swarmauri/tests/unit/pipelines/Pipeline_unit_test.py b/pkgs/swarmauri/tests/unit/pipelines/Pipeline_unit_test.py new file mode 100644 index 000000000..359d9d759 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/pipelines/Pipeline_unit_test.py @@ -0,0 +1,84 @@ +import pytest +from swarmauri.pipelines.concrete.Pipeline import Pipeline +from swarmauri_core.pipelines.IPipeline import PipelineStatus + + +@pytest.fixture(scope="module") +def simple_tasks(): + def task1(): + return "Task 1 completed" + + def task2(x): + return f"Task 2 with {x}" + + def task3(x, y): + return x + y + + return [task1, task2, task3] + + +@pytest.fixture(scope="module") +def pipeline(simple_tasks): + pipeline = Pipeline() + pipeline.add_task(simple_tasks[0]) + 
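# Extra positional arguments appear to be forwarded to the task at execution time (editorial note, inferred from the execution assertions below). + 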
pipeline.add_task(simple_tasks[1], "parameter") + pipeline.add_task(simple_tasks[2], 10, 20) + return pipeline + + +@pytest.mark.unit +def test_ubc_resource(pipeline): + assert pipeline.resource == "Pipeline" + + +@pytest.mark.unit +def test_ubc_type(pipeline): + assert pipeline.type == "Pipeline" + + +@pytest.mark.unit +def test_serialization(pipeline): + assert pipeline.id == Pipeline.model_validate_json(pipeline.model_dump_json()).id + + +@pytest.mark.unit +def test_pipeline_initial_status(pipeline): + assert pipeline.get_status() == PipelineStatus.PENDING + + +@pytest.mark.unit +def test_pipeline_execution(pipeline): + results = pipeline.execute() + + assert len(results) == 3 + assert results[0] == "Task 1 completed" + assert results[1] == "Task 2 with parameter" + assert results[2] == 30 + assert pipeline.get_status() == PipelineStatus.COMPLETED + + +@pytest.mark.unit +def test_pipeline_reset(pipeline): + pipeline.reset() + assert pipeline.get_status() == PipelineStatus.PENDING + assert len(pipeline.get_results()) == 0 + + +@pytest.mark.unit +def test_pipeline_add_task(simple_tasks): + pipeline = Pipeline() + initial_task_count = len(pipeline.tasks) + + pipeline.add_task(simple_tasks[0]) + assert len(pipeline.tasks) == initial_task_count + 1 + + +@pytest.mark.unit +def test_pipeline_get_results(simple_tasks): + pipeline = Pipeline() + pipeline.add_task(simple_tasks[0]) + pipeline.execute() + + results = pipeline.get_results() + assert len(results) == 1 + assert results[0] == "Task 1 completed" diff --git a/pkgs/swarmauri/tests/unit/service_registries/ServiceRegistry_unit_test.py b/pkgs/swarmauri/tests/unit/service_registries/ServiceRegistry_unit_test.py new file mode 100644 index 000000000..6498d4c29 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/service_registries/ServiceRegistry_unit_test.py @@ -0,0 +1,82 @@ +import pytest +from swarmauri.service_registries.concrete.ServiceRegistry import ServiceRegistry + + +@pytest.fixture +def service_registry(): + return ServiceRegistry() + + +@pytest.mark.unit +def test_ubc_resource(service_registry): + assert service_registry.resource == "ServiceRegistry" + + +@pytest.mark.unit +def test_ubc_type(service_registry): + assert service_registry.type == "ServiceRegistry" + + +@pytest.mark.unit +def test_serialization(service_registry): + assert ( + service_registry.id + == service_registry.model_validate_json(service_registry.model_dump_json()).id + ) + + +@pytest.mark.unit +def test_register_service(service_registry): + service_registry.register_service("auth", {"role": "authentication"}) + assert service_registry.services["auth"] == {"role": "authentication"} + + +@pytest.mark.unit +def test_get_service(service_registry): + service_registry.register_service("auth", {"role": "authentication"}) + service = service_registry.get_service("auth") + assert service == {"role": "authentication"} + assert service_registry.get_service("nonexistent") is None + + +@pytest.mark.unit +def test_get_services_by_roles(service_registry): + service_registry.register_service("auth", {"role": "authentication"}) + service_registry.register_service("db", {"role": "database"}) + recipients = service_registry.get_services_by_roles(["authentication"]) + assert recipients == ["auth"] + recipients = service_registry.get_services_by_roles(["database"]) + assert recipients == ["db"] + recipients = service_registry.get_services_by_roles(["authentication", "database"]) + assert set(recipients) == {"auth", "db"} + + +@pytest.mark.unit +def 
test_unregister_service(service_registry): + service_registry.register_service("auth", {"role": "authentication"}) + service_registry.unregister_service("auth") + assert "auth" not in service_registry.services + + +@pytest.mark.unit +def test_unregister_service_nonexistent(service_registry): + with pytest.raises(ValueError) as exc_info: + service_registry.unregister_service("nonexistent") + assert str(exc_info.value) == "Service nonexistent not found." + + +@pytest.mark.unit +def test_update_service(service_registry): + service_registry.register_service("auth", {"role": "authentication"}) + service_registry.update_service("auth", {"role": "auth_service", "version": "1.0"}) + assert service_registry.services["auth"] == { + "role": "auth_service", + "version": "1.0", + } + + +@pytest.mark.unit +def test_update_service_nonexistent(service_registry): + with pytest.raises(ValueError) as exc_info: + service_registry.update_service("nonexistent", {"role": "new_role"}) + assert str(exc_info.value) == "Service nonexistent not found." diff --git a/pkgs/swarmauri/tests/unit/state/DictState_unit_test.py b/pkgs/swarmauri/tests/unit/state/DictState_unit_test.py new file mode 100644 index 000000000..8450efb47 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/state/DictState_unit_test.py @@ -0,0 +1,85 @@ +import pytest +from swarmauri.state.concrete.DictState import DictState + + +@pytest.fixture +def dict_state(): + """ + Fixture to create a DictState instance for testing. + """ + # Create DictState + state = DictState() + + # Yield the state for tests to use + yield state + + +@pytest.mark.unit +def test_resource_type(dict_state): + """ + Test the resource type of the DictState. + """ + assert dict_state.resource == "State" + + +@pytest.mark.unit +def test_write_and_read(dict_state): + """ + Test writing data to DictState and reading it back. + """ + test_data = {"key1": "value1", "key2": 42} + dict_state.write(test_data) + read_data = dict_state.read() + assert read_data == test_data + + +@pytest.mark.unit +def test_update(dict_state): + """ + Test updating existing DictState data. + """ + # Initial write + initial_data = {"existing_key": "existing_value"} + dict_state.write(initial_data) + + # Update with new data + update_data = {"new_key": "new_value"} + dict_state.update(update_data) + + # Read and verify merged data + read_data = dict_state.read() + expected_data = {"existing_key": "existing_value", "new_key": "new_value"} + assert read_data == expected_data + + +@pytest.mark.unit +def test_reset(dict_state): + """ + Test resetting the DictState to an empty dictionary. 
+ """ + # Write some data + dict_state.write({"some_key": "some_value"}) + + # Reset + dict_state.reset() + + # Verify empty state + assert dict_state.read() == {} + + +@pytest.mark.unit +def test_deep_copy(dict_state): + # Write initial data + initial_data = {"key1": "value1", "key2": "value2"} + dict_state.write(initial_data) + + # Create deep copy + copied_state = dict_state.deep_copy() + + # Verify copied state + assert isinstance(copied_state, DictState) + assert copied_state.read() == initial_data + + # Verify deep copy by modifying original and copy independently + dict_state.update({"new_key": "new_value"}) + assert copied_state.read() == initial_data # Copy should remain unchanged diff --git a/pkgs/swarmauri/tests/unit/swarms/Swarm_unit_test.py b/pkgs/swarmauri/tests/unit/swarms/Swarm_unit_test.py new file mode 100644 index 000000000..966ee48c9 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/swarms/Swarm_unit_test.py @@ -0,0 +1,91 @@ +from typing import Any +from pydantic import BaseModel +import pytest +import asyncio +from swarmauri.swarms.base.SwarmBase import SwarmStatus +from swarmauri.swarms.concrete.Swarm import Swarm + + +class MockAgent(BaseModel): + async def process(self, task: Any, **kwargs) -> str: + if task == "fail": + raise Exception("Task failed") + return f"Processed {task}" + +@pytest.fixture +def swarm(): + return Swarm(agent_class=MockAgent, num_agents=3, agent_timeout=0.5, max_retries=2) + +@pytest.mark.unit +def test_ubc_resource(swarm): + assert swarm.resource == "Swarm" + + +@pytest.mark.unit +def test_ubc_type(swarm): + assert swarm.type == "Swarm" + + +@pytest.mark.unit +def test_serialization(swarm): + assert swarm.id == Swarm.model_validate_json(swarm.model_dump_json()).id + + +@pytest.mark.asyncio +async def test_swarm_initialization(swarm): + assert len(swarm.agents) == 3 + assert swarm.queue_size == 0 + assert all(s == SwarmStatus.IDLE for s in swarm.get_swarm_status().values()) + + +@pytest.mark.asyncio +async def test_single_task_execution(swarm): + results = await swarm.exec("task1") + assert len(results) == 1 + assert results[0]["status"] == SwarmStatus.COMPLETED + assert results[0]["result"] == "Processed task1" + assert "agent_id" in results[0] + + +@pytest.mark.asyncio +async def test_multiple_tasks_execution(swarm): + tasks = ["task1", "task2", "task3"] + results = await swarm.exec(tasks) + assert len(results) == 3 + assert all(r["status"] == SwarmStatus.COMPLETED for r in results) + assert all("Processed" in r["result"] for r in results) + + +@pytest.mark.asyncio +async def test_failed_task_handling(swarm): + results = await swarm.exec("fail") + assert len(results) == 1 + assert results[0]["status"] == SwarmStatus.FAILED + assert "error" in results[0] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_swarm_status_changes(swarm): + # Create tasks + tasks = ["task1"] * 3 + + # Start execution + task_future = asyncio.create_task(swarm.exec(tasks)) + + # Wait briefly for tasks to start + await asyncio.sleep(0.1) + + # Check intermediate status + status = swarm.get_swarm_status() + assert any( + s in [SwarmStatus.WORKING, SwarmStatus.COMPLETED] for s in status.values() + ) + + # Wait for completion with timeout + try: + results = await asyncio.wait_for(task_future, timeout=2.0) + assert len(results) == len(tasks) + assert all(r["status"] == SwarmStatus.COMPLETED for r in results) + except asyncio.TimeoutError: + task_future.cancel() + raise diff --git a/pkgs/swarmauri/tests/unit/task_mgt_strategies/RoundRobinStrategy_unit_test.py 
b/pkgs/swarmauri/tests/unit/task_mgt_strategies/RoundRobinStrategy_unit_test.py new file mode 100644 index 000000000..c68a1a7d9 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/task_mgt_strategies/RoundRobinStrategy_unit_test.py @@ -0,0 +1,132 @@ +import pytest +from unittest.mock import MagicMock +from swarmauri.task_mgt_strategies.concrete.RoundRobinStrategy import RoundRobinStrategy + + +@pytest.fixture +def round_robin_strategy(): + """Fixture to create a RoundRobinStrategy instance.""" + return RoundRobinStrategy() + + +def test_assign_task(round_robin_strategy): + # Setup + task = {"task_id": "task1", "payload": "data"} + service_registry = MagicMock(return_value=["service1", "service2"]) + + # Execute + round_robin_strategy.assign_task(task, service_registry) + + # Verify + assert round_robin_strategy.task_assignments["task1"] == "service1" + assert round_robin_strategy.current_index == 1 + + +def test_assign_task_no_services(round_robin_strategy): + # Setup + task = {"task_id": "task1", "payload": "data"} + service_registry = MagicMock(return_value=[]) + + # Execute & Verify + with pytest.raises(ValueError) as exc_info: + round_robin_strategy.assign_task(task, service_registry) + assert str(exc_info.value) == "No services available for task assignment." + + +def test_add_task(round_robin_strategy): + # Setup + task = {"task_id": "task1", "payload": "data"} + + # Execute + round_robin_strategy.add_task(task) + + # Verify + assert not round_robin_strategy.task_queue.empty() + queued_task = round_robin_strategy.task_queue.get() + assert queued_task == task + + +def test_remove_task(round_robin_strategy): + # Setup + task_id = "task1" + round_robin_strategy.task_assignments[task_id] = "service1" + + # Execute + round_robin_strategy.remove_task(task_id) + + # Verify + assert task_id not in round_robin_strategy.task_assignments + + +def test_remove_task_not_found(round_robin_strategy): + # Setup + task_id = "task1" + + # Execute & Verify + with pytest.raises(ValueError) as exc_info: + round_robin_strategy.remove_task(task_id) + assert str(exc_info.value) == "Task 'task1' not found in assignments." + + +def test_get_task(round_robin_strategy): + # Setup + task_id = "task1" + round_robin_strategy.task_assignments[task_id] = "service1" + + # Execute + result = round_robin_strategy.get_task(task_id) + + # Verify + expected_result = {"task_id": task_id, "assigned_service": "service1"} + assert result == expected_result + + +def test_get_task_not_found(round_robin_strategy): + # Setup + task_id = "task1" + + # Execute & Verify + with pytest.raises(ValueError) as exc_info: + round_robin_strategy.get_task(task_id) + assert str(exc_info.value) == "Task 'task1' not found in assignments." 
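+
+
+# Note: these tests model the service registry as a plain callable
+# (MagicMock(return_value=[...])); RoundRobinStrategy is assumed to invoke it
+# directly to obtain the list of services available for assignment.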
+ + +def test_process_tasks(round_robin_strategy): + # Setup + service_registry = MagicMock(return_value=["service1", "service2"]) + transport = MagicMock() + tasks = [ + {"task_id": "task1", "payload": "data1"}, + {"task_id": "task2", "payload": "data2"}, + {"task_id": "task3", "payload": "data3"}, + ] + for task in tasks: + round_robin_strategy.add_task(task) + + # Execute + round_robin_strategy.process_tasks(service_registry, transport) + + # Verify assignments + assert round_robin_strategy.task_assignments["task1"] == "service1" + assert round_robin_strategy.task_assignments["task2"] == "service2" + assert round_robin_strategy.task_assignments["task3"] == "service1" + assert round_robin_strategy.current_index == 3 + + # Verify that transport.send was called correctly + transport.send.assert_any_call(tasks[0], "service1") + transport.send.assert_any_call(tasks[1], "service2") + transport.send.assert_any_call(tasks[2], "service1") + assert transport.send.call_count == 3 + + +def test_process_tasks_no_services(round_robin_strategy): + # Setup + service_registry = MagicMock(return_value=[]) + transport = MagicMock() + task = {"task_id": "task1", "payload": "data"} + round_robin_strategy.add_task(task) + + # Execute & Verify + with pytest.raises(ValueError) as exc_info: + round_robin_strategy.process_tasks(service_registry, transport) + assert "No services available for task assignment." in str(exc_info.value) diff --git a/pkgs/swarmauri/tests/unit/transports/PubSubTransport_unit_test.py b/pkgs/swarmauri/tests/unit/transports/PubSubTransport_unit_test.py new file mode 100644 index 000000000..fb5aeeb5e --- /dev/null +++ b/pkgs/swarmauri/tests/unit/transports/PubSubTransport_unit_test.py @@ -0,0 +1,138 @@ +import pytest +import asyncio +from uuid import UUID +from typing import Any +from swarmauri.transports.concrete.PubSubTransport import ( + PubSubTransport, +) +from swarmauri.utils.timeout_wrapper import timeout +import logging + + +@pytest.fixture +def pubsub_transport(): + transport = PubSubTransport() + return transport + + +@timeout(5) +@pytest.mark.unit +def test_ubc_resource(pubsub_transport): + assert pubsub_transport.resource == "Transport" + + +@timeout(5) +@pytest.mark.unit +def test_ubc_type(pubsub_transport): + assert pubsub_transport.type == "PubSubTransport" + + +@timeout(5) +@pytest.mark.unit +def test_serialization(pubsub_transport): + assert ( + pubsub_transport.id + == PubSubTransport.model_validate_json(pubsub_transport.model_dump_json()).id + ) + + +@timeout(5) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_subscribe(pubsub_transport): + topic = "test_topic" + subscriber_id = await pubsub_transport.subscribe(topic) + + # Validate subscriber ID format + assert isinstance(UUID(subscriber_id), UUID) + + # Ensure subscriber is added to the topic + assert subscriber_id in pubsub_transport._topics[topic] + + +@timeout(5) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_unsubscribe(pubsub_transport): + topic = "test_topic" + subscriber_id = await pubsub_transport.subscribe(topic) + + await pubsub_transport.unsubscribe(topic, subscriber_id) + + # Ensure subscriber is removed from the topic + assert subscriber_id not in pubsub_transport._topics.get(topic, set()) + + +@timeout(5) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_publish_and_receive(pubsub_transport): + topic = "test_topic" + subscriber_id = await pubsub_transport.subscribe(topic) + + message = "Hello, PubSub!" 
+ await pubsub_transport.publish(topic, message) + + # Ensure the subscriber receives the message + received_message = await pubsub_transport.receive(subscriber_id) + assert received_message == message + + +@timeout(5) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_broadcast(pubsub_transport): + topic1 = "topic1" + topic2 = "topic2" + subscriber_id1 = await pubsub_transport.subscribe(topic1) + subscriber_id2 = await pubsub_transport.subscribe(topic2) + + message = "Broadcast Message" + pubsub_transport.broadcast("sender_id", message) + + # Ensure both subscribers receive the message + received_message1 = await pubsub_transport.receive(subscriber_id1) + received_message2 = await pubsub_transport.receive(subscriber_id2) + assert received_message1 == message + assert received_message2 == message + + +@timeout(5) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_multicast(pubsub_transport): + topic1 = "topic1" + topic2 = "topic2" + topic3 = "topic3" + subscriber_id1 = await pubsub_transport.subscribe(topic1) + subscriber_id2 = await pubsub_transport.subscribe(topic2) + subscriber_id3 = await pubsub_transport.subscribe(topic3) + + message = "Multicast Message" + pubsub_transport.multicast("sender_id", [topic1, topic2], message) + + # Ensure only subscribers of specified topics receive the message + received_message1 = await pubsub_transport.receive(subscriber_id1) + received_message2 = await pubsub_transport.receive(subscriber_id2) + assert received_message1 == message + assert received_message2 == message + + try: + await asyncio.wait_for(pubsub_transport.receive(subscriber_id3), timeout=1.0) + pytest.fail("Expected no message, but received one.") + except asyncio.TimeoutError: + pass + + +@timeout(5) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_receive_no_messages(pubsub_transport): + topic = "test_topic" + subscriber_id = await pubsub_transport.subscribe(topic) + + try: + await asyncio.wait_for(pubsub_transport.receive(subscriber_id), timeout=1.0) + pytest.fail("Expected no message, but received one.") + except asyncio.TimeoutError: + pass diff --git a/pkgs/swarmauri/tests/unit/utils/method_signature_extractor_decorator_test.py b/pkgs/swarmauri/tests/unit/utils/method_signature_extractor_decorator_test.py new file mode 100644 index 000000000..ea14ca385 --- /dev/null +++ b/pkgs/swarmauri/tests/unit/utils/method_signature_extractor_decorator_test.py @@ -0,0 +1,172 @@ +from typing import List, Optional, Union, Dict, Any +from swarmauri.utils.method_signature_extractor_decorator import ( + MethodSignatureExtractor, + extract_method_signature, +) + + +# Test functions with various signature types +def simple_function(x: int, y: str): + return x, y + + +def optional_function(a: Optional[int] = None, b: str = "default"): + return a, b + + +def list_function(items: List[str], count: int = 1): + return items, count + + +def union_function(value: Union[int, str]): + return value + + +def complex_function( + x: int, y: Optional[List[float]] = None, z: Union[str, Dict[str, Any]] = "default" +): + return x, y, z + + +class TestMethodSignatureExtractor: + def test_simple_function_extraction(self): + """Test extraction of a simple function with basic types""" + extractor = MethodSignatureExtractor(method=simple_function) + + assert len(extractor.parameters) == 2 + + # Check first parameter + assert extractor.parameters[0].name == "x" + assert extractor.parameters[0].type == "integer" + assert extractor.parameters[0].required is True + + # Check second parameter + assert 
extractor.parameters[1].name == "y" + assert extractor.parameters[1].type == "string" + assert extractor.parameters[1].required is True + + def test_optional_function_extraction(self): + """Test extraction of a function with optional parameters""" + extractor = MethodSignatureExtractor(method=optional_function) + + assert len(extractor.parameters) == 2 + + # Check first parameter (optional) + assert extractor.parameters[0].name == "a" + assert extractor.parameters[0].type == "integer" + assert extractor.parameters[0].required is False + + # Check second parameter (with default) + assert extractor.parameters[1].name == "b" + assert extractor.parameters[1].type == "string" + assert extractor.parameters[1].required is False + + def test_list_function_extraction(self): + """Test extraction of a function with list parameter""" + extractor = MethodSignatureExtractor(method=list_function) + + assert len(extractor.parameters) == 2 + + # Check first parameter (list) + assert extractor.parameters[0].name == "items" + assert extractor.parameters[0].type == "array" + assert extractor.parameters[0].required is True + + # Check second parameter (with default) + assert extractor.parameters[1].name == "count" + assert extractor.parameters[1].type == "integer" + assert extractor.parameters[1].required is False + + def test_union_function_extraction(self): + """Test extraction of a function with union type""" + extractor = MethodSignatureExtractor(method=union_function) + + assert len(extractor.parameters) == 1 + + # Check union parameter + assert extractor.parameters[0].name == "value" + assert extractor.parameters[0].type is not None + assert len(extractor.parameters[0].type) == 2 + + def test_complex_function_extraction(self): + """Test extraction of a function with multiple complex types""" + extractor = MethodSignatureExtractor(method=complex_function) + + assert len(extractor.parameters) == 3 + + # Check first parameter + assert extractor.parameters[0].name == "x" + assert extractor.parameters[0].type == "integer" + assert extractor.parameters[0].required is True + + # Check second parameter (optional list) + assert extractor.parameters[1].name == "y" + assert extractor.parameters[1].type == "array" + assert extractor.parameters[1].required is False + + # Check third parameter (union type with default) + assert extractor.parameters[2].name == "z" + assert extractor.parameters[2].type is not None + assert extractor.parameters[2].required is False + + def test_decorator_signature_extraction(self): + """Test the extract_method_signature decorator""" + + @extract_method_signature + def test_decorator_func(a: int, b: Optional[str] = None): + pass + + # Check if signature_details is added to the function + assert hasattr(test_decorator_func, "signature_details") + + # Verify the details + details = test_decorator_func.signature_details + assert len(details) == 2 + + # First parameter + assert details[0].name == "a" + assert details[0].type == "integer" + assert details[0].required is True + + # Second parameter + assert details[1].name == "b" + assert details[1].type == "string" + assert details[1].required is False + + def test_type_mapping(self): + """Test the type mapping functionality""" + extractor = MethodSignatureExtractor(method=simple_function) + + # Check predefined type mappings + type_mapping = extractor._type_mapping + assert type_mapping[int] == "integer" + assert type_mapping[float] == "number" + assert type_mapping[str] == "string" + assert type_mapping[bool] == "boolean" + assert 
type_mapping[list] == "array" + assert type_mapping[dict] == "object" + assert type_mapping[Any] == "any" + + +# Additional edge case tests +def test_empty_function(): + """Test function with no parameters""" + + def empty_func(): + pass + + extractor = MethodSignatureExtractor(method=empty_func) + assert len(extractor.parameters) == 0 + + +def test_method_with_self(): + """Test method of a class with self parameter""" + + class TestClass: + def method(self, x: int): + return x + + extractor = MethodSignatureExtractor(method=TestClass.method) + assert len(extractor.parameters) == 1 + assert extractor.parameters[0].name == "x" + assert extractor.parameters[0].type == "integer" diff --git a/scripts/classify_json_results.py b/scripts/classify_json_results.py new file mode 100644 index 000000000..51de4357a --- /dev/null +++ b/scripts/classify_json_results.py @@ -0,0 +1,131 @@ +import json +import sys +from collections import defaultdict + +def parse_arguments(args): + """Parse command-line arguments.""" + import argparse + + parser = argparse.ArgumentParser(description="Analyze test results from a JSON file.") + parser.add_argument("file", help="Path to the JSON file containing test results") + parser.add_argument("--required-passed", type=str, help="Required passed percentage (e.g., 'gt:50', 'lt:30', 'eq:50', 'ge:50', 'le:50')") + parser.add_argument("--required-skipped", type=str, help="Required skipped percentage (e.g., 'gt:20', 'lt:50', 'eq:50', 'ge:50', 'le:50')") + + return parser.parse_args(args) + + +def evaluate_threshold(value, threshold): + """Evaluate if the value meets the threshold condition.""" + try: + op, limit = threshold.split(":") + limit = float(limit) + if op == "gt": + return value > limit + elif op == "lt": + return value < limit + elif op == "eq": + return value == limit + elif op == "ge": + return value >= limit + elif op == "le": + return value <= limit + else: + raise ValueError(f"Invalid operator '{op}'. Use 'gt', 'lt', 'eq', 'ge', or 'le'.") + except ValueError as e: + raise ValueError(f"Invalid threshold format '{threshold}'. 
Expected format: 'gt:<value>', 'lt:<value>', 'eq:<value>', 'ge:<value>', or 'le:<value>'") from e
+
+
+def analyze_tags_from_file(file_path, required_passed=None, required_skipped=None):
+    try:
+        # Load JSON data from the file
+        with open(file_path, 'r') as f:
+            data = json.load(f)
+
+        # Extract the summary and list of tests
+        summary = data.get("summary", {})
+        tests = data.get("tests", [])
+
+        # Check if the summary and tests exist
+        if not summary or not tests:
+            print("No test data or summary found in the provided file.")
+            return
+
+        # Get total number of tests
+        total_tests = summary.get("total", 0)
+
+        # Print summary with percentage
+        print("\nTest Results Summary:")
+        print(f"{'Category':<15} {'Count':<10} {'Total':<10} {'% of Total':<10}")
+        print("-" * 50)
+        for category in ["passed", "skipped", "failed"]:
+            count = summary.get(category, 0)
+            percentage = (count / total_tests) * 100 if total_tests > 0 else 0
+            print(f"{category.capitalize():<15} {count:<10} {total_tests:<10} {percentage:<10.2f}")
+
+        # Check thresholds
+        passed_percentage = (summary.get("passed", 0) / total_tests) * 100 if total_tests > 0 else 0
+        skipped_percentage = (summary.get("skipped", 0) / total_tests) * 100 if total_tests > 0 else 0
+
+        threshold_error = False
+        if required_passed and not evaluate_threshold(passed_percentage, required_passed):
+            print(f"\nWARNING: Passed percentage ({passed_percentage:.2f}%) does not meet the condition '{required_passed}'!")
+            threshold_error = True
+        if required_skipped and not evaluate_threshold(skipped_percentage, required_skipped):
+            print(f"WARNING: Skipped percentage ({skipped_percentage:.2f}%) does not meet the condition '{required_skipped}'!")
+            threshold_error = True
+
+        print("\n")
+
+        # Group tests by tags
+        tag_outcomes = defaultdict(lambda: {"passed": 0, "total": 0})
+
+        for test in tests:
+            outcome = test["outcome"]
+            for tag in test["keywords"]:
+                # Exclusion criteria for tags
+                if (
+                    tag == "tests" or            # Exclude tag "tests"
+                    tag.startswith("test_") or   # Exclude tags starting with "test_"
+                    tag.endswith("_test.py") or  # Exclude tags ending with "_test.py"
+                    tag == ""                    # Exclude empty tags
+                ):
+                    continue
+                tag_outcomes[tag]["total"] += 1
+                if outcome == "passed":
+                    tag_outcomes[tag]["passed"] += 1
+
+        # Print detailed results by tags
+        print("Tag-Based Results:")
+        print(f"{'Tag':<30} {'Passed':<10} {'Total':<10} {'% Passed':<10}")
+        print("-" * 70)
+        for tag, outcomes in tag_outcomes.items():
+            passed = outcomes["passed"]
+            total = outcomes["total"]
+            percentage = (passed / total) * 100 if total > 0 else 0
+            print(f"{tag:<30} {passed:<10} {total:<10} {percentage:<10.2f}")
+
+        # Exit with error code if thresholds are not met
+        if threshold_error:
+            sys.exit(1)
+
+    except FileNotFoundError:
+        print(f"Error: File not found: {file_path}")
+        sys.exit(1)
+    except json.JSONDecodeError:
+        print(f"Error: Failed to decode JSON from file: {file_path}")
+        sys.exit(1)
+    except Exception as e:
+        print(f"An unexpected error occurred: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    # Parse arguments
+    args = parse_arguments(sys.argv[1:])
+
+    # Run the analysis
+    analyze_tags_from_file(
+        file_path=args.file,
+        required_passed=args.required_passed,
+        required_skipped=args.required_skipped
+    )
diff --git a/scripts/manage_issues.py b/scripts/manage_issues.py
new file mode 100644
index 000000000..bab5fb1ad
--- /dev/null
+++ b/scripts/manage_issues.py
@@ -0,0 +1,165 @@
+import os
+import json
+import requests
+from swarmauri.llms.concrete.GroqModel import GroqModel
+from
swarmauri.agents.concrete.SimpleConversationAgent import SimpleConversationAgent +from swarmauri.conversations.concrete.Conversation import Conversation +import argparse + +# GitHub API settings +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +REPO = os.getenv("REPO") +HEADERS = { + "Authorization": f"Bearer {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +} + +# GroqModel Initialization +GROQ_API_KEY = os.getenv("GROQ_API_KEY") +llm = GroqModel(api_key=GROQ_API_KEY) + +BASE_BRANCH = os.getenv("GITHUB_HEAD_REF") or os.getenv("GITHUB_REF", "unknown").split("/")[-1] +COMMIT_SHA = os.getenv("GITHUB_SHA", "unknown") +WORKFLOW_RUN_URL = os.getenv("GITHUB_SERVER_URL", "https://github.com") + f"/{REPO}/actions/runs/{os.getenv('GITHUB_RUN_ID', 'unknown')}" + +def parse_arguments(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Manage GitHub Issues for Test Failures") + parser.add_argument("--results-file", type=str, required=True, help="Path to the pytest JSON report file") + parser.add_argument("--package", type=str, required=True, help="Name of the matrix package where the error occurred.") + return parser.parse_args() + +def load_pytest_results(report_file): + """Load pytest results from the JSON report.""" + if not os.path.exists(report_file): + print(f"Report file not found: {report_file}") + return [] + + with open(report_file, "r") as f: + report = json.load(f) + + failures = [] + for test in report.get("tests", []): + if test.get("outcome") == "failed": + test_name = test.get("nodeid", "Unknown test") + failure_message = test.get("call", {}).get("longrepr", "No failure details available") + failures.append({ + "name": test_name, + "path": test.get("nodeid", "Unknown path"), + "message": failure_message + }) + return failures + +def ask_groq_for_fix(test_name, failure_message, stack_trace): + """Ask GroqModel for suggestions on fixing the test failure.""" + prompt = f""" + I have a failing test case named '{test_name}' in a Python project. The error message is: + {failure_message} + + The stack trace is: + {stack_trace} + + Can you help me identify the cause of this failure and suggest a fix? + """ + try: + agent = SimpleConversationAgent(llm=llm, conversation=Conversation()) + response = agent.exec(input_str=prompt) + return response + except Exception as e: + print(f"Error communicating with Groq: {e}") + return "Unable to retrieve suggestions from Groq at this time." 
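+
+# Hypothetical local invocation, assuming GITHUB_TOKEN, REPO, and GROQ_API_KEY
+# are exported and a pytest JSON report has already been generated:
+#   python scripts/manage_issues.py --results-file=pytest_results.json --package=community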
+ +def get_existing_issues(): + """Retrieve all existing issues with the pytest-failure label.""" + url = f"https://api.github.com/repos/{REPO}/issues" + params = {"labels": "pytest-failure", "state": "open"} + response = requests.get(url, headers=HEADERS, params=params) + response.raise_for_status() + return response.json() + +def create_issue(test, package): + """Create a new GitHub issue for the test failure.""" + groq_suggestion = ask_groq_for_fix(test["name"], test["message"], test["message"]) + url = f"https://api.github.com/repos/{REPO}/issues" + + # Construct the issue body + data = { + "title": f"[Test Case Failure]: {test['name']}", + "body": f""" +### Test Case: +{test['path']} + +### Failure Details: +{test['message']} + +--- + +### Suggested Fix (via Groq): +{groq_suggestion} + +--- + +### Context: +- **Branch**: [{BASE_BRANCH}](https://github.com/{REPO}/tree/{BASE_BRANCH}) +- **Commit**: [{COMMIT_SHA}](https://github.com/{REPO}/commit/{COMMIT_SHA}) +- **Commit Tree**: [{COMMIT_SHA}](https://github.com/{REPO}/tree/{COMMIT_SHA}) +- **Workflow Run**: [View Run]({WORKFLOW_RUN_URL}) +- **Matrix Package**: `{package}` + +### Labels: +This issue is auto-labeled for the `{package}` package. +""", + "labels": ["pytest-failure", package] + } + response = requests.post(url, headers=HEADERS, json=data) + response.raise_for_status() + print(f"Issue created for {test['name']} with Groq suggestion in package '{package}'.") + +def add_comment_to_issue(issue_number, test, package): + """Add a comment to an existing GitHub issue.""" + url = f"https://api.github.com/repos/{REPO}/issues/{issue_number}/comments" + data = {"body": f""" +New failure detected: + +### Test Case: +{test['path']} + +### Details: +{test['message']} + +--- + +### Context: +- **Branch**: [{BASE_BRANCH}](https://github.com/{REPO}/tree/{BASE_BRANCH}) +- **Commit**: [{COMMIT_SHA}](https://github.com/{REPO}/commit/{COMMIT_SHA}) +- **Commit Tree**: [{COMMIT_SHA}](https://github.com/{REPO}/tree/{COMMIT_SHA}) +- **Workflow Run**: [View Run]({WORKFLOW_RUN_URL}) +- **Matrix Package**: `{package}` +"""} + response = requests.post(url, headers=HEADERS, json=data) + response.raise_for_status() + print(f"Comment added to issue {issue_number} for {test['name']}.") + +def process_failures(report_file, package): + """Process pytest failures and manage GitHub issues.""" + failures = load_pytest_results(report_file) + if not failures: + print("No test failures found.") + return + + existing_issues = get_existing_issues() + + for test in failures: + issue_exists = False + for issue in existing_issues: + if test["name"] in issue["title"]: + add_comment_to_issue(issue["number"], test, package) + issue_exists = True + break + + if not issue_exists: + create_issue(test, package) + +if __name__ == "__main__": + args = parse_arguments() + process_failures(args.results_file, args.package) diff --git a/scripts/rag_issue_manager.py b/scripts/rag_issue_manager.py new file mode 100644 index 000000000..c6bc936e3 --- /dev/null +++ b/scripts/rag_issue_manager.py @@ -0,0 +1,216 @@ +import os +import json +import requests +from swarmauri.utils.load_documents_from_folder import load_documents_from_folder +from swarmauri.llms.concrete.DeepInfraModel import DeepInfraModel +from swarmauri.agents.concrete.RagAgent import RagAgent +from swarmauri.conversations.concrete.MaxSystemContextConversation import MaxSystemContextConversation +from swarmauri.vector_stores.concrete.TfidfVectorStore import TfidfVectorStore +import argparse + +# GitHub API settings 
+GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
+REPO = os.getenv("REPO")
+HEADERS = {
+    "Authorization": f"Bearer {GITHUB_TOKEN}",
+    "Accept": "application/vnd.github.v3+json",
+}
+
+# GitHub Metadata
+BASE_BRANCH = os.getenv("GITHUB_BASE_REF", "unknown")
+HEAD_BRANCH = os.getenv("GITHUB_HEAD_REF") or os.getenv("GITHUB_REF", "unknown").split("/")[-1]
+COMMIT_SHA = os.getenv("GITHUB_SHA", "unknown")
+WORKFLOW_RUN_URL = os.getenv("GITHUB_SERVER_URL", "https://github.com") + f"/{REPO}/actions/runs/{os.getenv('GITHUB_RUN_ID', 'unknown')}"
+
+# DeepInfraModel Initialization
+DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
+llm = DeepInfraModel(api_key=DEEPINFRA_API_KEY, name="meta-llama/Meta-Llama-3.1-8B-Instruct")
+
+def parse_arguments():
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser(description="Manage GitHub Issues for Test Failures")
+    parser.add_argument("--results-file", type=str, required=True, help="Path to the pytest JSON report file")
+    parser.add_argument("--package", type=str, required=True, help="Name of the matrix package where the error occurred.")
+    return parser.parse_args()
+
+def load_pytest_results(report_file):
+    """Load pytest results from the JSON report."""
+    if not os.path.exists(report_file):
+        print(f"Report file not found: {report_file}")
+        return []
+
+    with open(report_file, "r") as f:
+        report = json.load(f)
+
+    failures = []
+    for test in report.get("tests", []):
+        if test.get("outcome") == "failed":
+            test_name = test.get("nodeid", "Unknown test")
+            failure_message = test.get("call", {}).get("longrepr", "No failure details available")
+            failures.append({
+                "name": test_name,
+                "path": test.get("nodeid", "Unknown path"),
+                "message": failure_message
+            })
+    return failures
+
+def ask_agent_for_fix(test_name, failure_message, stack_trace):
+    """Ask the LLM for suggestions on fixing the test failure."""
+    system_context = "Utilizing the following codebase, solve the user's problem.\nCodebase:\n\n"
+    prompt = f"""
+    \n\nUser Problem:
+    I have a failing pytest test case named '{test_name}' in a Python project. The error message is:
+    `{failure_message}`
+    \n\n
+    The stack trace is:
+    `{stack_trace}`
+    \n\n
+    Can you help me identify the cause of this failure and suggest a fix?
+    """
+    try:
+        current_directory = os.getcwd()
+        documents = load_documents_from_folder(folder_path=current_directory,
+                                               include_extensions=['py', 'md', 'yaml', 'toml', 'json'])
+        print(f"Loaded {len(documents)} documents.")
+
+        # Initialize the TF-IDF vector store and add the loaded documents
+        vector_store = TfidfVectorStore()
+        vector_store.add_documents(documents)
+        print("Documents have been added to the TFIDF Vector Store.")
+
+        # Initialize the RagAgent with the vector store and language model
+        rag_agent = RagAgent(system_context=system_context,
+                             llm=llm, vector_store=vector_store, conversation=MaxSystemContextConversation())
+        print("RagAgent initialized successfully.")
+        response = rag_agent.exec(input_data=prompt, top_k=20, llm_kwargs={"max_tokens": 750, "temperature":0.7})
+        print(f"\nPrompt: \n{prompt}\n\n", '-'*10, f"Response: {response}\n\n", '='*10, '\n')
+        return response
+    except Exception as e:
+        print(f"Error communicating with LLM: {e}")
+        return "Unable to retrieve suggestions from LLM at this time."
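+
+# The report parsed by load_pytest_results() above is assumed to follow the
+# pytest JSON report layout, roughly:
+#   {"tests": [{"nodeid": "...", "outcome": "failed",
+#               "call": {"longrepr": "<traceback text>"}}], ...}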
+ +def get_existing_issues(): + """Retrieve all existing issues with the pytest-failure label.""" + url = f"https://api.github.com/repos/{REPO}/issues" + params = {"labels": "pytest-failure", "state": "open"} + response = requests.get(url, headers=HEADERS, params=params) + response.raise_for_status() + return response.json() + +def create_issue(test, package): + """Create a new GitHub issue for the test failure.""" + suggestion = ask_agent_for_fix(test["name"], test["message"], test["message"]) + url = f"https://api.github.com/repos/{REPO}/issues" + if package in {'core', 'community', 'experimental'}: + package_name = f"swarmauri_{package}" + else: + package_name = "swarmauri" + resource_kind, type_kind = test['name'].split('/')[2:4] + comp_file_url = f"pkgs/{package}/{package_name}/{resource_kind}/concrete/{type_kind.split('_')[0]}.py" + test_file_url = f"pkgs/{package}/{test['name'].split('::')[0]}" + + # Construct the issue body + data = { + "title": f"[Test Case Failure]: {test['name']}", + "body": f""" +### Test Case: +`{test['path']}` + +### Failure Details: +```python +{test['message']} +``` + +--- + +### Suggested Fix (via Agent): +{suggestion} + +--- + +### Context: +- **Commit**: `{COMMIT_SHA}` +- **Matrix Package**: [{package}](https://github.com/{REPO}/tree/{HEAD_BRANCH}/pkgs/{package}) +- **Component**: [{comp_file_url}](https://github.com/{REPO}/tree/{HEAD_BRANCH}/{comp_file_url}) +- **Test File**: [{test_file_url}](https://github.com/{REPO}/tree/{HEAD_BRANCH}/{test_file_url}) +- **Branches**: [{HEAD_BRANCH}](https://github.com/{REPO}/tree/{HEAD_BRANCH}) ==> [{BASE_BRANCH}](https://github.com/{REPO}/tree/{BASE_BRANCH}) +- **Changes**: [View Changes](https://github.com/{REPO}/commit/{COMMIT_SHA}) +- **Files**: [View Files](https://github.com/{REPO}/tree/{COMMIT_SHA}) +- **Workflow Run**: [View Run]({WORKFLOW_RUN_URL}) + +### Labels: +This issue is auto-labeled for the `{package}` package. 
+""", + "labels": ["pytest-failure", package] + } + response = requests.post(url, headers=HEADERS, json=data) + response.raise_for_status() + print(f"Issue created for {test['name']} with LLM suggestion in package '{package}'.") + +def add_comment_to_issue(issue_number, test, package): + """Add a comment to an existing GitHub issue.""" + suggestion = ask_agent_for_fix(test["name"], test["message"], test["message"]) + url = f"https://api.github.com/repos/{REPO}/issues/{issue_number}/comments" + if package in {'core', 'community', 'experimental'}: + package_name = f"swarmauri_{package}" + else: + package_name = "swarmauri" + resource_kind, type_kind = test['name'].split('/')[2:4] + comp_file_url = f"pkgs/{package}/{package_name}/{resource_kind}/concrete/{type_kind.split('_')[0]}.py" + test_file_url = f"pkgs/{package}/{test['name'].split('::')[0]}" + data = {"body": f""" +New failure detected: + +### Test Case: +`{test['path']}` + +### Failure Details: +```python +{test['message']} +``` + +--- + +### Suggested Fix (via Agent): +{suggestion} + +--- + +### Context: +- **Commit**: `{COMMIT_SHA}` +- **Matrix Package**: [{package}](https://github.com/{REPO}/tree/{HEAD_BRANCH}/pkgs/{package}) +- **Component**: [{comp_file_url}](https://github.com/{REPO}/tree/{HEAD_BRANCH}/{comp_file_url}) +- **Test File**: [{test_file_url}](https://github.com/{REPO}/tree/{HEAD_BRANCH}/{test_file_url}) +- **Branches**: [{HEAD_BRANCH}](https://github.com/{REPO}/tree/{HEAD_BRANCH}) ==> [{BASE_BRANCH}](https://github.com/{REPO}/tree/{BASE_BRANCH}) +- **Changes**: [View Changes](https://github.com/{REPO}/commit/{COMMIT_SHA}) +- **Files**: [View Files](https://github.com/{REPO}/tree/{COMMIT_SHA}) +- **Workflow Run**: [View Run]({WORKFLOW_RUN_URL}) + +"""} + response = requests.post(url, headers=HEADERS, json=data) + response.raise_for_status() + print(f"Comment added to issue {issue_number} for {test['name']}.") + +def process_failures(report_file, package): + """Process pytest failures and manage GitHub issues.""" + failures = load_pytest_results(report_file) + if not failures: + print("No test failures found.") + return + + existing_issues = get_existing_issues() + + for test in failures: + issue_exists = False + for issue in existing_issues: + if test["name"] in issue["title"]: + add_comment_to_issue(issue["number"], test, package) + issue_exists = True + break + + if not issue_exists: + create_issue(test, package) + +if __name__ == "__main__": + args = parse_arguments() + process_failures(args.results_file, args.package)