Skip to content

Commit

Permalink
Only control write permission in the ChatBot class
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Jan 7, 2017
1 parent d81039f commit 001cfed
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 212 deletions.
6 changes: 5 additions & 1 deletion chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def __init__(self, name, **kwargs):

self.logger = kwargs.get('logger', logging.getLogger(__name__))

# Allow the bot to save input it receives so that it can learn
self.read_only = kwargs.get('read_only', False)

if kwargs.get('initialize', True):
self.initialize()

Expand Down Expand Up @@ -138,7 +141,8 @@ def learn_response(self, statement, previous_statement):
))

# Save the statement after selecting a response
self.storage.update(statement)
if not self.read_only:
self.storage.update(statement)

def set_trainer(self, training_class, **kwargs):
"""
Expand Down
36 changes: 17 additions & 19 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,36 +73,34 @@ def filter(self, **kwargs):

return statements

def update(self, statement, **kwargs):
def update(self, statement):
"""
Update the provided statement.
"""
from chatterbot.ext.django_chatterbot.models import Statement as StatementModel

response_statement_cache = statement.response_statement_cache

# Do not alter the database unless writing is enabled
if not self.read_only:
statement, created = StatementModel.objects.get_or_create(text=statement.text)
statement.extra_data = getattr(statement, 'extra_data', '')
statement.save()
statement, created = StatementModel.objects.get_or_create(text=statement.text)
statement.extra_data = getattr(statement, 'extra_data', '')
statement.save()

for _response_statement in response_statement_cache:
for _response_statement in response_statement_cache:

response_statement, created = StatementModel.objects.get_or_create(
text=_response_statement.text
)
response_statement.extra_data = getattr(_response_statement, 'extra_data', '')
response_statement.save()
response_statement, created = StatementModel.objects.get_or_create(
text=_response_statement.text
)
response_statement.extra_data = getattr(_response_statement, 'extra_data', '')
response_statement.save()

response, created = statement.in_response.get_or_create(
statement=statement,
response=response_statement
)
response, created = statement.in_response.get_or_create(
statement=statement,
response=response_statement
)

if not created:
response.occurrence += 1
response.save()
if not created:
response.occurrence += 1
response.save()

return statement

Expand Down
32 changes: 13 additions & 19 deletions chatterbot/storage/jsonfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class JsonFileStorageAdapter(StorageAdapter):
:keyword silence_performance_warning: If set to True, the :code:`UnsuitableForProductionWarning`
will not be displayed.
:type silence_performance_warning: bool
:keyword read_only: If set to True, ChatterBot will not save information to the database.
False by default.
:type read_only: bool
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -154,24 +150,22 @@ def filter(self, **kwargs):

return results

def update(self, statement, **kwargs):
def update(self, statement):
"""
Update a statement in the database.
"""
# Do not alter the database unless writing is enabled
if not self.read_only:
data = statement.serialize()

# Remove the text key from the data
del data['text']
self.database.data(key=statement.text, value=data)

# Make sure that an entry for each response exists
for response_statement in statement.in_response_to:
response = self.find(response_statement.text)
if not response:
response = self.Statement(response_statement.text)
self.update(response)
data = statement.serialize()

# Remove the text key from the data
del data['text']
self.database.data(key=statement.text, value=data)

# Make sure that an entry for each response exists
for response_statement in statement.in_response_to:
response = self.find(response_statement.text)
if not response:
response = self.Statement(response_statement.text)
self.update(response)

return statement

Expand Down
52 changes: 22 additions & 30 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ class MongoDatabaseAdapter(StorageAdapter):
.. code-block:: python
database_uri='mongodb://example.com:8100/'
:keyword read_only: If set to True, ChatterBot will not save information to the database.
False by default.
:type read_only: bool
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -206,41 +201,38 @@ def filter(self, **kwargs):

return results

def update(self, statement, **kwargs):
def update(self, statement):
from pymongo import UpdateOne
from pymongo.errors import BulkWriteError

force = kwargs.get('force', False)
# Do not alter the database unless writing is enabled
if force or not self.read_only:
data = statement.serialize()
data = statement.serialize()

operations = []

update_operation = UpdateOne(
{'text': statement.text},
{'$set': data},
upsert=True
)
operations.append(update_operation)

operations = []
# Make sure that an entry for each response is saved
for response_dict in data.get('in_response_to', []):
response_text = response_dict.get('text')

# $setOnInsert does nothing if the document is not created
update_operation = UpdateOne(
{'text': statement.text},
{'$set': data},
{'text': response_text},
{'$set': response_dict},
upsert=True
)
operations.append(update_operation)

# Make sure that an entry for each response is saved
for response_dict in data.get('in_response_to', []):
response_text = response_dict.get('text')

# $setOnInsert does nothing if the document is not created
update_operation = UpdateOne(
{'text': response_text},
{'$set': response_dict},
upsert=True
)
operations.append(update_operation)

try:
self.statements.bulk_write(operations, ordered=False)
except BulkWriteError as bwe:
# Log the details of a bulk write error
self.logger.error(str(bwe.details))
try:
self.statements.bulk_write(operations, ordered=False)
except BulkWriteError as bwe:
# Log the details of a bulk write error
self.logger.error(str(bwe.details))

return statement

Expand Down
1 change: 0 additions & 1 deletion chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, base_query=None, *args, **kwargs):
"""
self.kwargs = kwargs
self.logger = kwargs.get('logger', logging.getLogger(__name__))
self.read_only = kwargs.get('read_only', False)
self.adapter_supports_queries = True
self.base_query = None

Expand Down
6 changes: 3 additions & 3 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def train(self, conversation):
)

statement_history.append(statement)
self.storage.update(statement, force=True)
self.storage.update(statement)


class ChatterBotCorpusTrainer(Trainer):
Expand Down Expand Up @@ -204,7 +204,7 @@ def train(self):
for _ in range(0, 10):
statements = self.get_statements()
for statement in statements:
self.storage.update(statement, force=True)
self.storage.update(statement)


class UbuntuCorpusTrainer(Trainer):
Expand Down Expand Up @@ -345,4 +345,4 @@ def train(self):
)

statement_history.append(statement)
self.storage.update(statement, force=True)
self.storage.update(statement)
29 changes: 0 additions & 29 deletions tests/storage_adapter_tests/integration_tests/base.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from tests.base_case import ChatBotTestCase
from .base import StorageIntegrationTests


class JsonStorageIntegrationTests(StorageIntegrationTests, ChatBotTestCase):
pass
class JsonStorageIntegrationTests(ChatBotTestCase):

def test_database_is_updated(self):
"""
Test that the database is updated when read_only is set to false.
"""
input_text = 'What is the airspeed velocity of an unladen swallow?'
exists_before = self.chatbot.storage.find(input_text)

response = self.chatbot.get_response(input_text)
exists_after = self.chatbot.storage.find(input_text)

self.assertFalse(exists_before)
self.assertTrue(exists_after)
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from tests.base_case import ChatBotMongoTestCase
from .base import StorageIntegrationTests


class MongoStorageIntegrationTests(StorageIntegrationTests, ChatBotMongoTestCase):
pass
class MongoStorageIntegrationTests(ChatBotMongoTestCase):

def test_database_is_updated(self):
"""
Test that the database is updated when read_only is set to false.
"""
input_text = 'What is the airspeed velocity of an unladen swallow?'
exists_before = self.chatbot.storage.find(input_text)

response = self.chatbot.get_response(input_text)
exists_after = self.chatbot.storage.find(input_text)

self.assertFalse(exists_before)
self.assertTrue(exists_after)
30 changes: 0 additions & 30 deletions tests/storage_adapter_tests/test_json_file_storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,33 +404,3 @@ def test_order_by_created_at(self):
self.assertEqual(len(results), 2)
self.assertEqual(results[0], statement_a)
self.assertEqual(results[1], statement_b)


class ReadOnlyJsonFileStorageAdapterTestCase(JsonAdapterTestCase):

def test_update_does_not_add_new_statement(self):
self.adapter.read_only = True

statement = Statement("New statement")
self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(statement_found, None)

def test_update_does_not_modify_existing_statement(self):
statement = Statement("New statement")
self.adapter.update(statement)

self.adapter.read_only = True

statement.add_response(
Response("New response")
)

self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(statement_found.text, statement.text)
self.assertEqual(
len(statement_found.in_response_to), 0
)
31 changes: 0 additions & 31 deletions tests/storage_adapter_tests/test_mongo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,34 +430,3 @@ def test_order_by_created_at(self):
self.assertEqual(len(results), 2)
self.assertEqual(results[0], statement_a)
self.assertEqual(results[1], statement_b)


class ReadOnlyMongoDatabaseAdapterTestCase(MongoAdapterTestCase):

def test_update_does_not_add_new_statement(self):
self.adapter.read_only = True

statement = Statement("New statement")
self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(statement_found, None)

def test_update_does_not_modify_existing_statement(self):
statement = Statement("New statement")
self.adapter.update(statement)

self.adapter.read_only = True

statement.add_response(
Response("New response")
)
self.adapter.update(statement)

statement_found = self.adapter.find("New statement")
self.assertEqual(
statement_found.text, statement.text
)
self.assertEqual(
len(statement_found.in_response_to), 0
)
Loading

0 comments on commit 001cfed

Please sign in to comment.