diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 896607ed1..862ad62f5 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -66,6 +66,9 @@ def __init__(self, name, **kwargs): self.logger = kwargs.get('logger', logging.getLogger(__name__)) + # Allow the bot to save input it receives so that it can learn + self.read_only = kwargs.get('read_only', False) + if kwargs.get('initialize', True): self.initialize() @@ -138,7 +141,8 @@ def learn_response(self, statement, previous_statement): )) # Save the statement after selecting a response - self.storage.update(statement) + if not self.read_only: + self.storage.update(statement) def set_trainer(self, training_class, **kwargs): """ diff --git a/chatterbot/storage/django_storage.py b/chatterbot/storage/django_storage.py index 8a1c0c7b4..15efc8b14 100644 --- a/chatterbot/storage/django_storage.py +++ b/chatterbot/storage/django_storage.py @@ -73,7 +73,7 @@ def filter(self, **kwargs): return statements - def update(self, statement, **kwargs): + def update(self, statement): """ Update the provided statement. """ @@ -81,28 +81,26 @@ def update(self, statement, **kwargs): response_statement_cache = statement.response_statement_cache - # Do not alter the database unless writing is enabled - if not self.read_only: - statement, created = StatementModel.objects.get_or_create(text=statement.text) - statement.extra_data = getattr(statement, 'extra_data', '') - statement.save() + statement, created = StatementModel.objects.get_or_create(text=statement.text) + statement.extra_data = getattr(statement, 'extra_data', '') + statement.save() - for _response_statement in response_statement_cache: + for _response_statement in response_statement_cache: - response_statement, created = StatementModel.objects.get_or_create( - text=_response_statement.text - ) - response_statement.extra_data = getattr(_response_statement, 'extra_data', '') - response_statement.save() + response_statement, created = StatementModel.objects.get_or_create( + text=_response_statement.text + ) + response_statement.extra_data = getattr(_response_statement, 'extra_data', '') + response_statement.save() - response, created = statement.in_response.get_or_create( - statement=statement, - response=response_statement - ) + response, created = statement.in_response.get_or_create( + statement=statement, + response=response_statement + ) - if not created: - response.occurrence += 1 - response.save() + if not created: + response.occurrence += 1 + response.save() return statement diff --git a/chatterbot/storage/jsonfile.py b/chatterbot/storage/jsonfile.py index 6a7657094..c81ef8941 100644 --- a/chatterbot/storage/jsonfile.py +++ b/chatterbot/storage/jsonfile.py @@ -14,10 +14,6 @@ class JsonFileStorageAdapter(StorageAdapter): :keyword silence_performance_warning: If set to True, the :code:`UnsuitableForProductionWarning` will not be displayed. :type silence_performance_warning: bool - - :keyword read_only: If set to True, ChatterBot will not save information to the database. - False by default. - :type read_only: bool """ def __init__(self, **kwargs): @@ -154,24 +150,22 @@ def filter(self, **kwargs): return results - def update(self, statement, **kwargs): + def update(self, statement): """ Update a statement in the database. """ - # Do not alter the database unless writing is enabled - if not self.read_only: - data = statement.serialize() - - # Remove the text key from the data - del data['text'] - self.database.data(key=statement.text, value=data) - - # Make sure that an entry for each response exists - for response_statement in statement.in_response_to: - response = self.find(response_statement.text) - if not response: - response = self.Statement(response_statement.text) - self.update(response) + data = statement.serialize() + + # Remove the text key from the data + del data['text'] + self.database.data(key=statement.text, value=data) + + # Make sure that an entry for each response exists + for response_statement in statement.in_response_to: + response = self.find(response_statement.text) + if not response: + response = self.Statement(response_statement.text) + self.update(response) return statement diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index e3af2759d..e025c9c3e 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -76,11 +76,6 @@ class MongoDatabaseAdapter(StorageAdapter): .. code-block:: python database_uri='mongodb://example.com:8100/' - - - :keyword read_only: If set to True, ChatterBot will not save information to the database. - False by default. - :type read_only: bool """ def __init__(self, **kwargs): @@ -206,41 +201,38 @@ def filter(self, **kwargs): return results - def update(self, statement, **kwargs): + def update(self, statement): from pymongo import UpdateOne from pymongo.errors import BulkWriteError - force = kwargs.get('force', False) - # Do not alter the database unless writing is enabled - if force or not self.read_only: - data = statement.serialize() + data = statement.serialize() + + operations = [] + + update_operation = UpdateOne( + {'text': statement.text}, + {'$set': data}, + upsert=True + ) + operations.append(update_operation) - operations = [] + # Make sure that an entry for each response is saved + for response_dict in data.get('in_response_to', []): + response_text = response_dict.get('text') + # $setOnInsert does nothing if the document is not created update_operation = UpdateOne( - {'text': statement.text}, - {'$set': data}, + {'text': response_text}, + {'$set': response_dict}, upsert=True ) operations.append(update_operation) - # Make sure that an entry for each response is saved - for response_dict in data.get('in_response_to', []): - response_text = response_dict.get('text') - - # $setOnInsert does nothing if the document is not created - update_operation = UpdateOne( - {'text': response_text}, - {'$set': response_dict}, - upsert=True - ) - operations.append(update_operation) - - try: - self.statements.bulk_write(operations, ordered=False) - except BulkWriteError as bwe: - # Log the details of a bulk write error - self.logger.error(str(bwe.details)) + try: + self.statements.bulk_write(operations, ordered=False) + except BulkWriteError as bwe: + # Log the details of a bulk write error + self.logger.error(str(bwe.details)) return statement diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index 0c5fbc99c..decd49a96 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -13,7 +13,6 @@ def __init__(self, base_query=None, *args, **kwargs): """ self.kwargs = kwargs self.logger = kwargs.get('logger', logging.getLogger(__name__)) - self.read_only = kwargs.get('read_only', False) self.adapter_supports_queries = True self.base_query = None diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 9c32fc4e8..51d48e0b1 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -87,7 +87,7 @@ def train(self, conversation): ) statement_history.append(statement) - self.storage.update(statement, force=True) + self.storage.update(statement) class ChatterBotCorpusTrainer(Trainer): @@ -204,7 +204,7 @@ def train(self): for _ in range(0, 10): statements = self.get_statements() for statement in statements: - self.storage.update(statement, force=True) + self.storage.update(statement) class UbuntuCorpusTrainer(Trainer): @@ -345,4 +345,4 @@ def train(self): ) statement_history.append(statement) - self.storage.update(statement, force=True) + self.storage.update(statement) diff --git a/tests/storage_adapter_tests/integration_tests/base.py b/tests/storage_adapter_tests/integration_tests/base.py deleted file mode 100644 index e43d41027..000000000 --- a/tests/storage_adapter_tests/integration_tests/base.py +++ /dev/null @@ -1,29 +0,0 @@ -class StorageIntegrationTests(object): - - def test_database_is_updated(self): - """ - Test that the database is updated when read_only is set to false. - """ - input_text = "What is the airspeed velocity of an unladen swallow?" - exists_before = self.chatbot.storage.find(input_text) - - response = self.chatbot.get_response(input_text) - exists_after = self.chatbot.storage.find(input_text) - - self.assertFalse(exists_before) - self.assertTrue(exists_after) - - def test_database_is_not_updated_when_read_only(self): - """ - Test that the database is not updated when read_only is set to true. - """ - self.chatbot.storage.read_only = True - - input_text = "Who are you, the proud lord said?" - exists_before = self.chatbot.storage.find(input_text) - - response = self.chatbot.get_response(input_text) - exists_after = self.chatbot.storage.find(input_text) - - self.assertFalse(exists_before) - self.assertFalse(exists_after) diff --git a/tests/storage_adapter_tests/integration_tests/json_integration_tests.py b/tests/storage_adapter_tests/integration_tests/json_integration_tests.py index 000fdfa6b..1419163b0 100644 --- a/tests/storage_adapter_tests/integration_tests/json_integration_tests.py +++ b/tests/storage_adapter_tests/integration_tests/json_integration_tests.py @@ -1,6 +1,17 @@ from tests.base_case import ChatBotTestCase -from .base import StorageIntegrationTests -class JsonStorageIntegrationTests(StorageIntegrationTests, ChatBotTestCase): - pass \ No newline at end of file +class JsonStorageIntegrationTests(ChatBotTestCase): + + def test_database_is_updated(self): + """ + Test that the database is updated when read_only is set to false. + """ + input_text = 'What is the airspeed velocity of an unladen swallow?' + exists_before = self.chatbot.storage.find(input_text) + + response = self.chatbot.get_response(input_text) + exists_after = self.chatbot.storage.find(input_text) + + self.assertFalse(exists_before) + self.assertTrue(exists_after) diff --git a/tests/storage_adapter_tests/integration_tests/mongo_integration_tests.py b/tests/storage_adapter_tests/integration_tests/mongo_integration_tests.py index bdec054ca..8bc0b98a3 100644 --- a/tests/storage_adapter_tests/integration_tests/mongo_integration_tests.py +++ b/tests/storage_adapter_tests/integration_tests/mongo_integration_tests.py @@ -1,6 +1,17 @@ from tests.base_case import ChatBotMongoTestCase -from .base import StorageIntegrationTests -class MongoStorageIntegrationTests(StorageIntegrationTests, ChatBotMongoTestCase): - pass \ No newline at end of file +class MongoStorageIntegrationTests(ChatBotMongoTestCase): + + def test_database_is_updated(self): + """ + Test that the database is updated when read_only is set to false. + """ + input_text = 'What is the airspeed velocity of an unladen swallow?' + exists_before = self.chatbot.storage.find(input_text) + + response = self.chatbot.get_response(input_text) + exists_after = self.chatbot.storage.find(input_text) + + self.assertFalse(exists_before) + self.assertTrue(exists_after) diff --git a/tests/storage_adapter_tests/test_json_file_storage_adapter.py b/tests/storage_adapter_tests/test_json_file_storage_adapter.py index ae6bb8a3a..26b0ffce1 100644 --- a/tests/storage_adapter_tests/test_json_file_storage_adapter.py +++ b/tests/storage_adapter_tests/test_json_file_storage_adapter.py @@ -404,33 +404,3 @@ def test_order_by_created_at(self): self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) self.assertEqual(results[1], statement_b) - - -class ReadOnlyJsonFileStorageAdapterTestCase(JsonAdapterTestCase): - - def test_update_does_not_add_new_statement(self): - self.adapter.read_only = True - - statement = Statement("New statement") - self.adapter.update(statement) - - statement_found = self.adapter.find("New statement") - self.assertEqual(statement_found, None) - - def test_update_does_not_modify_existing_statement(self): - statement = Statement("New statement") - self.adapter.update(statement) - - self.adapter.read_only = True - - statement.add_response( - Response("New response") - ) - - self.adapter.update(statement) - - statement_found = self.adapter.find("New statement") - self.assertEqual(statement_found.text, statement.text) - self.assertEqual( - len(statement_found.in_response_to), 0 - ) diff --git a/tests/storage_adapter_tests/test_mongo_adapter.py b/tests/storage_adapter_tests/test_mongo_adapter.py index 9c63ff62f..c95f2b255 100644 --- a/tests/storage_adapter_tests/test_mongo_adapter.py +++ b/tests/storage_adapter_tests/test_mongo_adapter.py @@ -430,34 +430,3 @@ def test_order_by_created_at(self): self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) self.assertEqual(results[1], statement_b) - - -class ReadOnlyMongoDatabaseAdapterTestCase(MongoAdapterTestCase): - - def test_update_does_not_add_new_statement(self): - self.adapter.read_only = True - - statement = Statement("New statement") - self.adapter.update(statement) - - statement_found = self.adapter.find("New statement") - self.assertEqual(statement_found, None) - - def test_update_does_not_modify_existing_statement(self): - statement = Statement("New statement") - self.adapter.update(statement) - - self.adapter.read_only = True - - statement.add_response( - Response("New response") - ) - self.adapter.update(statement) - - statement_found = self.adapter.find("New statement") - self.assertEqual( - statement_found.text, statement.text - ) - self.assertEqual( - len(statement_found.in_response_to), 0 - ) diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 30034fb91..99fb4ef0f 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -3,10 +3,10 @@ from chatterbot.conversation import Statement, Response -class ChatterBotResponseTests(ChatBotTestCase): +class ChatterBotResponseTestCase(ChatBotTestCase): def setUp(self): - super(ChatterBotResponseTests, self).setUp() + super(ChatterBotResponseTestCase, self).setUp() response_list = [ Response('Hi') @@ -19,16 +19,16 @@ def test_empty_database(self): If there is no statements in the database, then the user's input is the only thing that can be returned. """ - response = self.chatbot.get_response("How are you?") + response = self.chatbot.get_response('How are you?') - self.assertEqual("How are you?", response) + self.assertEqual('How are you?', response) def test_statement_saved_empty_database(self): """ Test that when database is empty, the first statement is saved and returned as a response. """ - statement_text = "Wow!" + statement_text = 'Wow!' response = self.chatbot.get_response(statement_text) saved_statement = self.chatbot.storage.find(statement_text) @@ -40,7 +40,7 @@ def test_statement_added_to_recent_response_list(self): """ An input statement should be added to the recent response list. """ - statement_text = "Wow!" + statement_text = 'Wow!' response = self.chatbot.get_response(statement_text) session = self.chatbot.conversation_sessions.get( self.chatbot.default_session.id_string @@ -52,34 +52,34 @@ def test_statement_added_to_recent_response_list(self): def test_response_known(self): self.chatbot.storage.update(self.test_statement) - response = self.chatbot.get_response("Hi") + response = self.chatbot.get_response('Hi') self.assertEqual(response, self.test_statement.text) def test_response_format(self): self.chatbot.storage.update(self.test_statement) - response = self.chatbot.get_response("Hi") + response = self.chatbot.get_response('Hi') statement_object = self.chatbot.storage.find(response.text) self.assertEqual(response, self.test_statement.text) self.assertIsLength(statement_object.in_response_to, 1) - self.assertIn("Hi", statement_object.in_response_to) + self.assertIn('Hi', statement_object.in_response_to) def test_second_response_format(self): self.chatbot.storage.update(self.test_statement) - response = self.chatbot.get_response("Hi") - # response = "Hello" - second_response = self.chatbot.get_response("How are you?") + response = self.chatbot.get_response('Hi') + # response = 'Hello' + second_response = self.chatbot.get_response('How are you?') statement = self.chatbot.storage.find(second_response.text) # Make sure that the second response was saved to the database - self.assertIsNotNone(self.chatbot.storage.find("How are you?")) + self.assertIsNotNone(self.chatbot.storage.find('How are you?')) self.assertEqual(second_response, self.test_statement.text) self.assertIsLength(statement.in_response_to, 1) - self.assertIn("Hi", statement.in_response_to) + self.assertIn('Hi', statement.in_response_to) def test_get_response_unicode(self): """ @@ -94,7 +94,7 @@ def test_response_extra_data(self): `extra_data` attribute of a statement object, that data should saved with the input statement. """ - self.test_statement.add_extra_data("test", 1) + self.test_statement.add_extra_data('test', 1) self.chatbot.get_response( self.test_statement ) @@ -103,8 +103,8 @@ def test_response_extra_data(self): self.test_statement.text ) - self.assertIn("test", saved_statement.extra_data) - self.assertEqual(1, saved_statement.extra_data["test"]) + self.assertIn('test', saved_statement.extra_data) + self.assertEqual(1, saved_statement.extra_data['test']) def test_generate_response(self): statement = Statement('Many insects adopt a tripedal gait for rapid yet stable walking.') @@ -122,6 +122,19 @@ def test_learn_response(self): self.assertIsNotNone(exists) + def test_update_does_not_add_new_statement(self): + """ + Test that a new statement is not learned if `read_only` is set to True. + """ + self.chatbot.read_only = True + self.chatbot.storage.update(self.test_statement) + + response = self.chatbot.get_response('Hi!') + statement_found = self.chatbot.storage.find('Hi!') + + self.assertEqual(response, self.test_statement.text) + self.assertIsNone(statement_found) + class ChatBotConfigFileTestCase(ChatBotTestCase): diff --git a/tests_django/test_django_adapter.py b/tests_django/test_django_adapter.py index 8ca5b0082..c2d904124 100644 --- a/tests_django/test_django_adapter.py +++ b/tests_django/test_django_adapter.py @@ -334,29 +334,3 @@ def test_order_by_created_at(self): self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) self.assertEqual(results[1], statement_b) - - -class ReadOnlyDjangoAdapterTestCase(DjangoAdapterTestCase): - - def test_update_does_not_add_new_statement(self): - self.adapter.read_only = True - - statement = StatementModel(text="New statement") - self.adapter.update(statement) - - statement_found = self.adapter.find("New statement") - self.assertEqual(statement_found, None) - - def test_update_does_not_modify_existing_statement(self): - statement = StatementModel.objects.create(text="New statement") - - self.adapter.read_only = True - - statement.add_response( - StatementModel(text="New response") - ) - self.adapter.update(statement) - - statement_found = self.adapter.find("New statement") - self.assertEqual(statement_found.text, statement.text) - self.assertEqual(len(statement_found.in_response_to), 0)