
Lint and formatting fixes #554

Merged: 11 commits, Sep 5, 2023
7 changes: 4 additions & 3 deletions embedchain/bots/__init__.py
@@ -1,4 +1,5 @@
-from embedchain.bots.poe import PoeBot
-from embedchain.bots.whatsapp import WhatsAppBot
+from embedchain.bots.poe import PoeBot  # noqa: F401
+from embedchain.bots.whatsapp import WhatsAppBot  # noqa: F401

# TODO: fix discord import
-# from embedchain.bots.discord import DiscordBot
+# from embedchain.bots.discord import DiscordBot
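The `# noqa: F401` markers added above tell flake8 that these imports are intentional re-exports rather than dead code. A minimal, runnable sketch of the same pattern (stdlib modules stand in for the package's own submodules, which are illustrative here):

```python
# Re-export pattern for a package __init__.py; os.path stands in for a
# real submodule.
from os.path import join  # noqa: F401  (intentional re-export)
from os.path import split

# Alternative that needs no suppression: pyflakes treats names listed in
# __all__ as used, so an explicit public API also silences F401.
__all__ = ["split"]
```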
17 changes: 12 additions & 5 deletions embedchain/bots/whatsapp.py
@@ -1,4 +1,5 @@
import argparse
+import importlib
import logging
import signal
import sys
@@ -11,8 +12,14 @@
@register_deserializable
class WhatsAppBot(BaseBot):
    def __init__(self):
-        from flask import Flask, request
-        from twilio.twiml.messaging_response import MessagingResponse
+        try:
+            self.flask = importlib.import_module("flask")
+            self.twilio = importlib.import_module("twilio")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for WhatsApp are not installed. "
+                'Please install with `pip install --upgrade "embedchain[whatsapp]"`'
+            ) from None
        super().__init__()

    def handle_message(self, message):
@@ -41,7 +48,7 @@ def ask_bot(self, message):
        return response

    def start(self, host="0.0.0.0", port=5000, debug=True):
-        app = Flask(__name__)
+        app = self.flask.Flask(__name__)

        def signal_handler(sig, frame):
            logging.info("\nGracefully shutting down the WhatsAppBot...")
@@ -51,9 +58,9 @@ def signal_handler(sig, frame):

        @app.route("/chat", methods=["POST"])
        def chat():
-            incoming_message = request.values.get("Body", "").lower()
+            incoming_message = self.flask.request.values.get("Body", "").lower()
            response = self.handle_message(incoming_message)
-            twilio_response = MessagingResponse()
+            twilio_response = self.twilio.twiml.messaging_response.MessagingResponse()
            twilio_response.message(response)
            return str(twilio_response)
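The rewrite above replaces module-level `flask`/`twilio` imports with lazy ones, so the class can be imported even when the optional WhatsApp extras are missing. A self-contained sketch of the pattern (the helper name is illustrative, not part of embedchain):

```python
import importlib


def load_optional(module_name: str, extra: str):
    """Import an optional dependency, mapping the failure to an install hint."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            f"The required dependencies for {extra} are not installed. "
            f'Please install with `pip install --upgrade "embedchain[{extra}]"`'
        ) from None


# Importing the submodule directly is the safer variant: import_module("twilio")
# only exposes attributes that twilio's own __init__ happens to import, while
# import_module("twilio.twiml.messaging_response") always binds the submodule.
# messaging_response = load_optional("twilio.twiml.messaging_response", "whatsapp")
```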
2 changes: 1 addition & 1 deletion embedchain/embedchain.py
@@ -82,7 +82,7 @@ def __init__(
        # Send anonymous telemetry
        self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
        self.u_id = self._load_or_generate_user_id()
-        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
+        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
        # if (self.config.collect_metrics):
        #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
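For context, the surrounding code fires the anonymous telemetry event from a background thread so `__init__` never blocks on the network. A minimal sketch of that dispatch pattern (the event function is a stand-in, not the library's API):

```python
import threading


def send_telemetry_event(event_name: str) -> None:
    # Stand-in for the real network call.
    print(f"telemetry: {event_name}")


# A separate thread keeps a slow or failing request out of the caller's path.
thread = threading.Thread(target=send_telemetry_event, args=("init",))
thread.start()
thread.join()  # joined here only to make the example deterministic
```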
3 changes: 1 addition & 2 deletions embedchain/llm/antrophic_llm.py
@@ -2,9 +2,8 @@
from typing import Optional

from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


@register_deserializable
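This and the following `embedchain/llm/*.py` hunks are all the same isort fix: the first-party imports, previously split around a blank line, are merged into a single alphabetized group after the third-party block. The resulting convention, sketched on this file's imports:

```python
# Standard library imports first.
from typing import Optional

# Third-party imports second.
# (e.g. `import openai` or `from langchain.llms import Replicate` in the
# sibling modules)

# First-party imports last, as one alphabetized group.
from embedchain.config import BaseLlmConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
```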
3 changes: 1 addition & 2 deletions embedchain/llm/azure_openai_llm.py
@@ -2,9 +2,8 @@
from typing import Optional

from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


@register_deserializable
2 changes: 1 addition & 1 deletion embedchain/llm/base_llm.py
@@ -3,12 +3,12 @@

from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
-from embedchain.helper_classes.json_serializable import JSONSerializable

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
    DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.helper_classes.json_serializable import JSONSerializable


class BaseLlm(JSONSerializable):
3 changes: 1 addition & 2 deletions embedchain/llm/gpt4all_llm.py
@@ -1,9 +1,8 @@
from typing import Iterable, Optional, Union

from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


@register_deserializable
3 changes: 1 addition & 2 deletions embedchain/llm/llama2_llm.py
@@ -4,9 +4,8 @@
from langchain.llms import Replicate

from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


@register_deserializable
3 changes: 1 addition & 2 deletions embedchain/llm/openai_llm.py
@@ -3,9 +3,8 @@
import openai

from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


@register_deserializable
3 changes: 1 addition & 2 deletions embedchain/llm/vertex_ai_llm.py
@@ -2,9 +2,8 @@
from typing import Optional

from embedchain.config import BaseLlmConfig
-from embedchain.llm.base_llm import BaseLlm
-
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.base_llm import BaseLlm


@register_deserializable
11 changes: 5 additions & 6 deletions tests/llm/test_chat.py
@@ -1,7 +1,6 @@
-
import os
import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
@@ -88,8 +87,8 @@ def test_chat_with_where_in_params(self):

        self.assertEqual(answer, "Test answer")
        _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()

@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -120,6 +119,6 @@ def test_chat_with_where_in_chat_config(self):

        self.assertEqual(answer, "Test answer")
        _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()
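Besides the quote normalization, these assertions lean on `Mock.call_args`, which unpacks into the positional and keyword arguments of the most recent call. A runnable sketch, independent of embedchain:

```python
from unittest.mock import MagicMock

retrieve = MagicMock(return_value=["Test context"])
retrieve(input_query="Test query", where={"attribute": "value"})

# call_args is an (args, kwargs) pair for the last call made on the mock.
_args, kwargs = retrieve.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
```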
8 changes: 4 additions & 4 deletions tests/llm/test_query.py
@@ -109,8 +109,8 @@ def test_query_with_where_in_params(self):

        self.assertEqual(answer, "Test answer")
        _args, kwargs = mock_retrieve.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()

@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -142,6 +142,6 @@ def test_query_with_where_in_query_config(self):

        self.assertEqual(answer, "Test answer")
        _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get('input_query'), "Test query")
-        self.assertEqual(kwargs.get('where'), {"attribute": "value"})
+        self.assertEqual(kwargs.get("input_query"), "Test query")
+        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()
2 changes: 0 additions & 2 deletions tests/vectordb/test_chroma_db.py
@@ -7,7 +7,6 @@

from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
-from embedchain.models import EmbeddingFunctions, Providers
from embedchain.vectordb.chroma_db import ChromaDB

@@ -86,7 +85,6 @@ def test_init_with_host_and_port_log_level(self, mock_client):
        """
        Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
        """
-        config = AppConfig(log_level="DEBUG")

        _app = App(config=AppConfig(collect_metrics=False))
