diff --git a/embedchain/bots/__init__.py b/embedchain/bots/__init__.py
index b7f4d5a007..34cef58f26 100644
--- a/embedchain/bots/__init__.py
+++ b/embedchain/bots/__init__.py
@@ -1,4 +1,5 @@
-from embedchain.bots.poe import PoeBot
-from embedchain.bots.whatsapp import WhatsAppBot
+from embedchain.bots.poe import PoeBot # noqa: F401
+from embedchain.bots.whatsapp import WhatsAppBot # noqa: F401
+
 # TODO: fix discord import
-# from embedchain.bots.discord import DiscordBot
\ No newline at end of file
+# from embedchain.bots.discord import DiscordBot
diff --git a/embedchain/bots/whatsapp.py b/embedchain/bots/whatsapp.py
index cd9d1eaf67..4cdd1a94c9 100644
--- a/embedchain/bots/whatsapp.py
+++ b/embedchain/bots/whatsapp.py
@@ -1,4 +1,5 @@
 import argparse
+import importlib
 import logging
 import signal
 import sys
@@ -11,8 +12,14 @@
 @register_deserializable
 class WhatsAppBot(BaseBot):
     def __init__(self):
-        from flask import Flask, request
-        from twilio.twiml.messaging_response import MessagingResponse
+        try:
+            self.flask = importlib.import_module("flask")
+            self.twilio = importlib.import_module("twilio")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for WhatsApp are not installed. "
+                'Please install with `pip install --upgrade "embedchain[whatsapp]"`'
+            ) from None
         super().__init__()
 
     def handle_message(self, message):
@@ -41,7 +48,7 @@ def ask_bot(self, message):
         return response
 
     def start(self, host="0.0.0.0", port=5000, debug=True):
-        app = Flask(__name__)
+        app = self.flask.Flask(__name__)
 
         def signal_handler(sig, frame):
             logging.info("\nGracefully shutting down the WhatsAppBot...")
@@ -51,9 +58,9 @@ def signal_handler(sig, frame):
 
         @app.route("/chat", methods=["POST"])
         def chat():
-            incoming_message = request.values.get("Body", "").lower()
+            incoming_message = self.flask.request.values.get("Body", "").lower()
             response = self.handle_message(incoming_message)
-            twilio_response = MessagingResponse()
+            twilio_response = self.twilio.twiml.messaging_response.MessagingResponse()
             twilio_response.message(response)
             return str(twilio_response)
 
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 38f12d734c..e35d7eb4df 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -82,7 +82,7 @@ def __init__(
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
-        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event. 
+        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         # raise ConnectionRefusedError("Collection of metrics should not be allowed.")
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
diff --git a/embedchain/llm/antrophic_llm.py b/embedchain/llm/antrophic_llm.py
index b996e2be1e..90950fa6d2 100644
--- a/embedchain/llm/antrophic_llm.py
+++ b/embedchain/llm/antrophic_llm.py
@@ -2,9 +2,8 @@
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm
 
 
 @register_deserializable
diff --git a/embedchain/llm/azure_openai_llm.py b/embedchain/llm/azure_openai_llm.py
index 4ced9ca16e..40703446a3 100644
--- a/embedchain/llm/azure_openai_llm.py
+++ b/embedchain/llm/azure_openai_llm.py
@@ -2,9 +2,8 @@
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm
 
 
 @register_deserializable
diff --git a/embedchain/llm/base_llm.py b/embedchain/llm/base_llm.py
index 4c25173dde..ef49239237 100644
--- a/embedchain/llm/base_llm.py
+++ b/embedchain/llm/base_llm.py
@@ -3,12 +3,12 @@
 
 from langchain.memory import ConversationBufferMemory
 from langchain.schema import BaseMessage
 
-from embedchain.helper_classes.json_serializable import JSONSerializable
 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base_llm_config import (
     DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
     DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.helper_classes.json_serializable import JSONSerializable
 
 
 class BaseLlm(JSONSerializable):
diff --git a/embedchain/llm/gpt4all_llm.py b/embedchain/llm/gpt4all_llm.py
index 1624ae9f8a..36ea325b5c 100644
--- a/embedchain/llm/gpt4all_llm.py
+++ b/embedchain/llm/gpt4all_llm.py
@@ -1,9 +1,8 @@
 from typing import Iterable, Optional, Union
 
 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm
 
 
 @register_deserializable
diff --git a/embedchain/llm/llama2_llm.py b/embedchain/llm/llama2_llm.py
index 6a2d90a6e2..42ec1eff2f 100644
--- a/embedchain/llm/llama2_llm.py
+++ b/embedchain/llm/llama2_llm.py
@@ -4,9 +4,8 @@
 from langchain.llms import Replicate
 
 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm
 
 
 @register_deserializable
diff --git a/embedchain/llm/openai_llm.py b/embedchain/llm/openai_llm.py
index 320079f756..f22b271fd9 100644
--- a/embedchain/llm/openai_llm.py
+++ b/embedchain/llm/openai_llm.py
@@ -3,9 +3,8 @@
 import openai
 
 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm
 
 
 @register_deserializable
diff --git a/embedchain/llm/vertex_ai_llm.py b/embedchain/llm/vertex_ai_llm.py
index b1d47ad6a6..a5a9927a0f 100644
--- a/embedchain/llm/vertex_ai_llm.py
+++ b/embedchain/llm/vertex_ai_llm.py
@@ -2,9 +2,8 @@
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm
 
 
 @register_deserializable
diff --git a/tests/llm/test_chat.py b/tests/llm/test_chat.py
index cb50a89533..a5340b0798 100644
--- a/tests/llm/test_chat.py
+++ b/tests/llm/test_chat.py
@@ -1,7 +1,6 @@
-
 import os
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
 
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
@@ -88,8 +87,8 @@ def test_chat_with_where_in_params(self):
         self.assertEqual(answer, "Test answer")
 
         _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -120,6 +119,6 @@ def test_chat_with_where_in_chat_config(self):
         self.assertEqual(answer, "Test answer")
 
         _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()
diff --git a/tests/llm/test_query.py b/tests/llm/test_query.py
index 55bbb7660a..e16ebc90bc 100644
--- a/tests/llm/test_query.py
+++ b/tests/llm/test_query.py
@@ -109,8 +109,8 @@ def test_query_with_where_in_params(self):
         self.assertEqual(answer, "Test answer")
 
         _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -142,6 +142,6 @@ def test_query_with_where_in_query_config(self):
         self.assertEqual(answer, "Test answer")
 
         _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
         mock_answer.assert_called_once()
diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py
index 3188289bcf..91a386bbec 100644
--- a/tests/vectordb/test_chroma_db.py
+++ b/tests/vectordb/test_chroma_db.py
@@ -7,7 +7,6 @@
 
 from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
-from embedchain.models import EmbeddingFunctions, Providers
 from embedchain.vectordb.chroma_db import ChromaDB
 
 
@@ -86,7 +85,6 @@ def test_init_with_host_and_port_log_level(self, mock_client):
         """
         Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
         """
-        config = AppConfig(log_level="DEBUG")
         _app = App(config=AppConfig(collect_metrics=False))
 
 