diff --git a/.gitignore b/.gitignore index 116ddb6e9..69391cda8 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ docs/_build/ examples/settings.py examples/ubuntu_dialogs* +.env +.out +venv* diff --git a/chatterbot/storage/__init__.py b/chatterbot/storage/__init__.py index 2fe7e08b9..2efe59951 100644 --- a/chatterbot/storage/__init__.py +++ b/chatterbot/storage/__init__.py @@ -2,3 +2,4 @@ from .django_storage import DjangoStorageAdapter from .jsonfile import JsonFileStorageAdapter from .mongodb import MongoDatabaseAdapter +from .sqlalchemy_storage import SQLAlchemyDatabaseAdapter diff --git a/chatterbot/storage/sqlalchemy_storage.py b/chatterbot/storage/sqlalchemy_storage.py new file mode 100644 index 000000000..1c2fcbeb7 --- /dev/null +++ b/chatterbot/storage/sqlalchemy_storage.py @@ -0,0 +1,273 @@ +import json +import random + +from chatterbot.storage import StorageAdapter +from chatterbot.conversation import Response +from chatterbot.conversation import Statement + +_base = None + +try: + from sqlalchemy.ext.declarative import declarative_base + + _base = declarative_base() + + + class StatementTable(_base): + from sqlalchemy import Column, Integer, String, PickleType + from sqlalchemy.orm import relationship + + __tablename__ = 'StatementTable' + + def get_statement(self): + stmt = Statement(self.text, **self.extra_data) + for resp in self.in_response_to: + stmt.add_response(resp.get_response()) + return stmt + + def get_statement_serialized(context): + params = context.current_parameters + del (params['text_search']) + return json.dumps(params) + + id = Column(Integer) + text = Column(String, primary_key=True) + extra_data = Column(PickleType) + # relationship: + in_response_to = relationship("ResponseTable", back_populates="statement_table") + text_search = Column(String, primary_key=True, default=get_statement_serialized) + + + class ResponseTable(_base): + from sqlalchemy import Column, Integer, String, ForeignKey + from sqlalchemy.orm import relationship + __tablename__ = 'ResponseTable' + + def get_reponse_serialized(context): + params = context.current_parameters + del (params['text_search']) + return json.dumps(params) + + id = Column(Integer) + text = Column(String, primary_key=True) + occurrence = Column(Integer) + statement_text = Column(String, ForeignKey('StatementTable.text')) + + statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False) + text_search = Column(String, primary_key=True, default=get_reponse_serialized) + + def get_response(self): + occ = {"occurrence": self.occurrence} + return Response(text=self.text, **occ) + +except ImportError: + pass + + +def get_statement_table(statement): + responses = list(map(get_response_table, statement.in_response_to)) + return StatementTable(text=statement.text, in_response_to=responses, extra_data=statement.extra_data) + + +def get_response_table(response): + return ResponseTable(text=response.text, occurrence=response.occurrence) + + +class SQLAlchemyDatabaseAdapter(StorageAdapter): + read_only = False + drop_create = False + + def __init__(self, **kwargs): + super(SQLAlchemyDatabaseAdapter, self).__init__(**kwargs) + + from sqlalchemy import create_engine + + self.database_name = self.kwargs.get( + "database", "chatterbot-database" + ) + + # if some annoying blank space wrong... + db_name = self.database_name.strip() + + # default uses sqlite + self.database_uri = self.kwargs.get( + "database_uri", "sqlite:///" + db_name + ".db" + ) + + self.engine = create_engine(self.database_uri) + + self.read_only = self.kwargs.get( + "read_only", False + ) + + self.drop_create = self.kwargs.get( + "drop_create", False + ) + + if not self.read_only and self.drop_create: + _base.metadata.drop_all(self.engine) + _base.metadata.create_all(self.engine) + + def count(self): + """ + Return the number of entries in the database. + """ + session = self.__get_session() + return session.query(StatementTable).count() + + def __get_session(self): + """ + :rtype: Session + """ + from sqlalchemy.orm import sessionmaker + + Session = sessionmaker(bind=self.engine) + session = Session() + return session + + def __statement_filter(self, session, **kwargs): + """ + Apply filter operation on StatementTable + + rtype: query + """ + _query = session.query(StatementTable) + return _query.filter_by(**kwargs) + + def find(self, statement_text): + """ + Returns a statement if it exists otherwise None + """ + session = self.__get_session() + query = self.__statement_filter(session, **{"text": statement_text}) + record = query.first() + if record: + return record.get_statement() + return None + + def remove(self, statement_text): + """ + Removes the statement that matches the input text. + Removes any responses from statements where the response text matches + the input text. + """ + session = self.__get_session() + query = self.__statement_filter(session, **{"text": statement_text}) + record = query.first() + session.delete(record) + + self._session_finish(session, statement_text) + + def filter(self, **kwargs): + """ + Returns a list of objects from the database. + The kwargs parameter can contain any number + of attributes. Only objects which contain + all listed attributes and in which all values + match for all listed attributes will be returned. + """ + + filter_parameters = kwargs.copy() + + session = self.__get_session() + statements = [] + # _response_query = None + _query = None + if len(filter_parameters) == 0: + _response_query = session.query(StatementTable) + statements.extend(_response_query.all()) + else: + for i, fp in enumerate(filter_parameters): + _filter = filter_parameters[fp] + if fp in ['in_response_to', 'in_response_to__contains']: + _response_query = session.query(StatementTable) + if isinstance(_filter, list): + if len(_filter) == 0: + _query = _response_query.filter( + StatementTable.in_response_to == None) # Here must use == instead of is + else: + for f in _filter: + _query = _response_query.filter( + StatementTable.in_response_to.contains(get_response_table(f))) + else: + if fp == 'in_response_to__contains': + _query = _response_query.join(ResponseTable).filter(ResponseTable.text == _filter) + else: + _query = _response_query.filter(StatementTable.in_response_to == None) + else: + if _query: + _query = _query.filter(ResponseTable.text_search.like('%' + _filter + '%')) + else: + _response_query = session.query(ResponseTable) + _query = _response_query.filter(ResponseTable.text_search.like('%' + _filter + '%')) + + if _query is None: + return [] + if len(filter_parameters) == i + 1: + statements.extend(_query.all()) + + results = [] + + for statement in statements: + if isinstance(statement, ResponseTable): + if statement and statement.statement_table: + results.append(statement.statement_table.get_statement()) + else: + if statement: + results.append(statement.get_statement()) + + return results + + def update(self, statement): + """ + Modifies an entry in the database. + Creates an entry if one does not exist. + """ + session = self.__get_session() + if statement: + query = self.__statement_filter(session, **{"text": statement.text}) + record = query.first() + + if record: + # update + if statement.text: + record.text = statement.text + if statement.extra_data: + record.extra_data = dict[statement.extra_data] + if statement.in_response_to: + record.in_response_to = list(map(get_response_table, statement.in_response_to)) + session.add(record) + else: + session.add(get_statement_table(statement)) + + self._session_finish(session) + + def get_random(self): + """ + Returns a random statement from the database + """ + count = self.count() + if count < 1: + raise self.EmptyDatabaseException() + + rand = random.randrange(0, count) + session = self.__get_session() + stmt = session.query(StatementTable)[rand] + + return stmt.get_statement() + + def drop(self): + """ + Drop the database attached to a given adapter. + """ + _base.metadata.drop_all(self.engine) + + def _session_finish(self, session, statement_text=None): + from sqlalchemy.exc import DatabaseError + try: + if not self.read_only: + session.commit() + else: + session.rollback() + except DatabaseError as e: + self.logger.error(statement_text, str(e.orig)) diff --git a/requirements.txt b/requirements.txt index 27b69d708..8f28c03b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ jsondatabase>=0.1.7,<1.0.0 nltk>=3.2.0,<4.0.0 pymongo>=3.3.0,<4.0.0 python-twitter>=3.0.0,<4.0.0 +SQLAlchemy==1.1.7 \ No newline at end of file diff --git a/test-requirements.txt b/test-requirements.txt index 9dce2d73c..8cb31b5c1 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -7,4 +7,4 @@ nose nose-exclude>=0.5.0,<0.6.0 twython sphinx -sphinx_rtd_theme +sphinx_rtd_theme \ No newline at end of file diff --git a/tests/base_case.py b/tests/base_case.py index d2f0eed2a..4a1947c9e 100644 --- a/tests/base_case.py +++ b/tests/base_case.py @@ -62,4 +62,4 @@ def get_kwargs(self): kwargs = super(ChatBotMongoTestCase, self).get_kwargs() kwargs['database'] = self.random_string() kwargs['storage_adapter'] = 'chatterbot.storage.MongoDatabaseAdapter' - return kwargs + return kwargs \ No newline at end of file diff --git a/tests/storage_adapter_tests/integration_tests/sqlalchemy_integration_tests.py b/tests/storage_adapter_tests/integration_tests/sqlalchemy_integration_tests.py new file mode 100644 index 000000000..aa8a50078 --- /dev/null +++ b/tests/storage_adapter_tests/integration_tests/sqlalchemy_integration_tests.py @@ -0,0 +1,17 @@ +from tests.base_case import ChatBotTestCase + + +class SqlAlchemyStorageIntegrationTests(ChatBotTestCase): + + def test_database_is_updated(self): + """ + Test that the database is updated when read_only is set to false. + """ + input_text = 'What is the airspeed velocity of an unladen swallow?' + exists_before = self.chatbot.storage.find(input_text) + + response = self.chatbot.get_response(input_text) + exists_after = self.chatbot.storage.find(input_text) + + self.assertFalse(exists_before) + self.assertTrue(exists_after) diff --git a/tests/storage_adapter_tests/test_sqlalchemy_adapter.py b/tests/storage_adapter_tests/test_sqlalchemy_adapter.py new file mode 100644 index 000000000..3cd781726 --- /dev/null +++ b/tests/storage_adapter_tests/test_sqlalchemy_adapter.py @@ -0,0 +1,362 @@ +from unittest import TestCase + +from chatterbot.conversation import Statement, Response +from chatterbot.storage.sqlalchemy_storage import SQLAlchemyDatabaseAdapter + + +class SQLAlchemyAdapterTestCase(TestCase): + def setUp(self): + """ + Instantiate the adapter. + """ + from random import randint + + # Generate a random name for the database + database_name = str(randint(0, 9000)) + + self.adapter = SQLAlchemyDatabaseAdapter( + database='sqlite_' + database_name, + drop_create=True + ) + + +class SQLAlchemyDatabaseAdapterTestCase(SQLAlchemyAdapterTestCase): + def test_count_returns_zero(self): + """ + The count method should return a value of 0 + when nothing has been saved to the database. + """ + self.assertEqual(self.adapter.count(), 0) + + def test_count_returns_value(self): + """ + The count method should return a value of 1 + when one item has been saved to the database. + """ + statement = Statement("Test statement") + self.adapter.update(statement) + self.assertEqual(self.adapter.count(), 1) + + def test_statement_not_found(self): + """ + Test that None is returned by the find method + when a matching statement is not found. + """ + self.assertIsNone(self.adapter.find("Non-existant")) + + def test_statement_found(self): + """ + Test that a matching statement is returned + when it exists in the database. + """ + statement = Statement("New statement") + self.adapter.update(statement) + + found_statement = self.adapter.find("New statement") + self.assertIsNotNone(found_statement) + self.assertEqual(found_statement.text, statement.text) + + def test_update_adds_new_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + statement_found = self.adapter.find("New statement") + self.assertIsNotNone(statement_found) + self.assertEqual(statement_found.text, statement.text) + + def test_update_modifies_existing_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + # Check the initial values + found_statement = self.adapter.find(statement.text) + self.assertEqual( + len(found_statement.in_response_to), 0 + ) + + # Update the statement value + statement.add_response( + Response("New response") + ) + self.adapter.update(statement) + + # Check that the values have changed + found_statement = self.adapter.find(statement.text) + self.assertEqual( + len(found_statement.in_response_to), 1 + ) + + def test_get_random_returns_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + random_statement = self.adapter.get_random() + self.assertEqual(random_statement.text, statement.text) + + def test_find_returns_nested_responses(self): + response_list = [ + Response("Yes"), + Response("No") + ] + statement = Statement( + "Do you like this?", + in_response_to=response_list + ) + self.adapter.update(statement) + + result = self.adapter.find(statement.text) + + self.assertIn("Yes", result.in_response_to) + self.assertIn("No", result.in_response_to) + + def test_multiple_responses_added_on_update(self): + statement = Statement( + "You are welcome.", + in_response_to=[ + Response("Thank you."), + Response("Thanks.") + ] + ) + self.adapter.update(statement) + result = self.adapter.find(statement.text) + + self.assertEqual(len(result.in_response_to), 2) + self.assertIn(statement.in_response_to[0], result.in_response_to) + self.assertIn(statement.in_response_to[1], result.in_response_to) + + def test_update_saves_statement_with_multiple_responses(self): + statement = Statement( + "You are welcome.", + in_response_to=[ + Response("Thank you."), + Response("Thanks."), + ] + ) + self.adapter.update(statement) + response = self.adapter.find(statement.text) + + self.assertEqual(len(response.in_response_to), 2) + + def test_getting_and_updating_statement(self): + statement = Statement("Hi") + self.adapter.update(statement) + + statement.add_response(Response("Hello")) + statement.add_response(Response("Hello")) + self.adapter.update(statement) + + response = self.adapter.find(statement.text) + + self.assertEqual(len(response.in_response_to), 1) + self.assertEqual(response.in_response_to[0].occurrence, 2) + + def test_remove(self): + text = "Sometimes you have to run before you can walk." + statement = Statement(text) + self.adapter.update(statement) + self.adapter.remove(statement.text) + result = self.adapter.find(text) + + self.assertIsNone(result) + + def test_remove_response(self): + text = "Sometimes you have to run before you can walk." + statement = Statement( + "A test flight is not recommended at this design phase.", + in_response_to=[Response(text)] + ) + self.adapter.update(statement) + self.adapter.remove(statement.text) + results = self.adapter.filter(in_response_to__contains=text) + + self.assertEqual(results, []) + + def test_get_response_statements(self): + """ + Test that we are able to get a list of only statements + that are known to be in response to another statement. + """ + statement_list = [ + Statement("What... is your quest?"), + Statement("This is a phone."), + Statement("A what?", in_response_to=[Response("This is a phone.")]), + Statement("A phone.", in_response_to=[Response("A what?")]) + ] + + for statement in statement_list: + self.adapter.update(statement) + + responses = self.adapter.get_response_statements() + + self.assertEqual(len(responses), 2) + self.assertIn("This is a phone.", responses) + self.assertIn("A what?", responses) + + +class SQLAlchemyStorageAdapterFilterTestCase(SQLAlchemyAdapterTestCase): + def setUp(self): + super(SQLAlchemyStorageAdapterFilterTestCase, self).setUp() + + self.statement1 = Statement( + "Testing...", + in_response_to=[ + Response("Why are you counting?") + ] + ) + self.statement2 = Statement( + "Testing one, two, three.", + in_response_to=[ + Response("Testing...") + ] + ) + + def test_filter_text_no_matches(self): + self.adapter.update(self.statement1) + results = self.adapter.filter(text="Howdy") + + self.assertEqual(len(results), 0) + + def test_filter_in_response_to_no_matches(self): + self.adapter.update(self.statement1) + + results = self.adapter.filter( + in_response_to=[Response("Maybe")] + ) + self.assertEqual(len(results), 0) + + def test_filter_equal_results(self): + statement1 = Statement( + "Testing...", + in_response_to=[] + ) + statement2 = Statement( + "Testing one, two, three.", + in_response_to=[] + ) + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter(in_response_to=[]) + self.assertEqual(len(results), 2) + self.assertIn(statement1, results) + self.assertIn(statement2, results) + + def test_filter_contains_result(self): + self.adapter.update(self.statement1) + self.adapter.update(self.statement2) + + results = self.adapter.filter( + in_response_to__contains="Why are you counting?" + ) + self.assertEqual(len(results), 1) + self.assertIn(self.statement1, results) + + def test_filter_contains_no_result(self): + self.adapter.update(self.statement1) + + results = self.adapter.filter( + in_response_to__contains="How do you do?" + ) + self.assertEqual(results, []) + + def test_filter_multiple_parameters(self): + self.adapter.update(self.statement1) + self.adapter.update(self.statement2) + + results = self.adapter.filter( + text="Testing...", + in_response_to__contains = "Why are you counting?" + ) + + self.assertEqual(len(results), 1) + self.assertIn(self.statement1, results) + + def test_filter_multiple_parameters_no_results(self): + self.adapter.update(self.statement1) + self.adapter.update(self.statement2) + + results = self.adapter.filter( + text="Test", + in_response_to__contains="Not an existing response." + ) + + self.assertEqual(len(results), 0) + + def test_filter_no_parameters(self): + """ + If no parameters are passed to the filter, + then all statements should be returned. + """ + statement1 = Statement("Testing...") + statement2 = Statement("Testing one, two, three.") + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter() + + self.assertEqual(len(results), 2) + + def test_filter_returns_statement_with_multiple_responses(self): + statement = Statement( + "You are welcome.", + in_response_to=[ + Response("Thanks."), + Response("Thank you.") + ] + ) + self.adapter.update(statement) + response = self.adapter.filter( + in_response_to__contains="Thanks." + ) + + # Get the first response + response = response[0] + + self.assertEqual(len(response.in_response_to), 2) + + def test_response_list_in_results(self): + """ + If a statement with response values is found using + the filter method, they should be returned as + response objects. + """ + statement = Statement( + "The first is to help yourself, the second is to help others.", + in_response_to=[ + Response("Why do people have two hands?") + ] + ) + self.adapter.update(statement) + found = self.adapter.filter(text=statement.text) + + self.assertEqual(len(found[0].in_response_to), 1) + self.assertIsInstance(found[0].in_response_to[0], Response) + + +class ReadOnlySQLAlchemyDatabaseAdapterTestCase(SQLAlchemyAdapterTestCase): + def test_update_does_not_add_new_statement(self): + self.adapter.read_only = True + + statement = Statement("New statement") + self.adapter.update(statement) + + statement_found = self.adapter.find("New statement") + self.assertIsNone(statement_found) + + def test_update_does_not_modify_existing_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + self.adapter.read_only = True + + statement.add_response( + Response("New response") + ) + + self.adapter.update(statement) + + statement_found = self.adapter.find("New statement") + self.assertEqual(statement_found.text, statement.text) + self.assertEqual( + len(statement_found.in_response_to), 0 + ) diff --git a/tests/training_tests/test_chatterbot_corpus_training.py b/tests/training_tests/test_chatterbot_corpus_training.py index 5e2d08f4b..7f19c5232 100644 --- a/tests/training_tests/test_chatterbot_corpus_training.py +++ b/tests/training_tests/test_chatterbot_corpus_training.py @@ -69,4 +69,4 @@ def test_train_with_english_corpus_training_slash(self): self.chatbot.train(file_path) statement = self.chatbot.storage.find('Hello') - self.assertIsNotNone(statement) + self.assertIsNotNone(statement) \ No newline at end of file