diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index a8ace4ad2..1d6516411 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -142,9 +142,7 @@ def learn_response(self, statement, previous_statement): from .conversation import Response if previous_statement: - statement.add_response( - Response(previous_statement.text) - ) + statement.in_response_to = previous_statement self.logger.info('Adding "{}" as a response to "{}"'.format( statement.text, previous_statement.text diff --git a/chatterbot/conversation/__init__.py b/chatterbot/conversation/__init__.py index a4213715c..0a001a13e 100644 --- a/chatterbot/conversation/__init__.py +++ b/chatterbot/conversation/__init__.py @@ -1,5 +1,2 @@ from .statement import Statement -from .response import Response from .session import Session - -Conversation = Session diff --git a/chatterbot/conversation/response.py b/chatterbot/conversation/response.py deleted file mode 100644 index a1aff3fd8..000000000 --- a/chatterbot/conversation/response.py +++ /dev/null @@ -1,34 +0,0 @@ -class Response(object): - """ - A response represents an entity which response to a statement. - """ - - def __init__(self, text, **kwargs): - self.text = text - self.occurrence = kwargs.get('occurrence', 1) - - def __str__(self): - return self.text - - def __repr__(self): - return '' % (self.text) - - def __hash__(self): - return hash(self.text) - - def __eq__(self, other): - if not other: - return False - - if isinstance(other, Response): - return self.text == other.text - - return self.text == other - - def serialize(self): - data = {} - - data['text'] = self.text - data['occurrence'] = self.occurrence - - return data diff --git a/chatterbot/conversation/session.py b/chatterbot/conversation/session.py index b98d7e208..cdb3751ec 100644 --- a/chatterbot/conversation/session.py +++ b/chatterbot/conversation/session.py @@ -16,7 +16,10 @@ def all(self): """ Return all statements in the conversation. """ - return self.storage.filter(conversation__id=self.conversation_id) + return self.storage.filter( + conversation__id=self.conversation_id, + order_by='created_at' + ) def add(self, statement): """ @@ -25,6 +28,12 @@ def add(self, statement): statement.conversation_id = self.conversation_id self.storage.update(statement) + def count(self): + return len(self.all()) + + def exists(self): + return self.count() > 0 + class Session(object): """ @@ -52,7 +61,7 @@ def get_last_response_statement(self): statements = self.statements.all() if statements: # Return the latest output statement (This should be ordering them by date to get the latest) - return statements[-1] + return statements[1] return None diff --git a/chatterbot/conversation/statement.py b/chatterbot/conversation/statement.py index 6b8f3c2a6..885106b67 100644 --- a/chatterbot/conversation/statement.py +++ b/chatterbot/conversation/statement.py @@ -1,8 +1,27 @@ # -*- coding: utf-8 -*- -from .response import Response from datetime import datetime +class StatementSerializer(object): + + def serialize(self, obj): + """ + :returns: A dictionary representation of the statement object. + :rtype: dict + """ + data = {} + + data['text'] = obj.text + data['in_response_to'] = {'text': obj.in_response_to.text} + data['created_at'] = obj.created_at + data['extra_data'] = obj.extra_data + + return data + + def deserialize(self, data): + pass + + class Statement(object): """ A statement represents a single spoken entity, sentence or @@ -11,8 +30,7 @@ class Statement(object): def __init__(self, text, **kwargs): self.text = text - self.conversation_id = kwargs.pop('conversation_id', None) - self.in_response_to = kwargs.pop('in_response_to', []) + self.in_response_to = kwargs.pop('in_response_to', None) # The date and time that this statement was created at self.created_at = kwargs.pop('created_at', datetime.now()) @@ -68,81 +86,9 @@ def add_extra_data(self, key, value): """ self.extra_data[key] = value - def add_response(self, response): - """ - Add the response to the list of statements that this statement is in response to. - If the response is already in the list, increment the occurrence count of that response. - - :param response: The response to add. - :type response: `Response` - """ - if not isinstance(response, Response): - raise Statement.InvalidTypeException( - 'A {} was recieved when a {} instance was expected'.format( - type(response), - type(Response('')) - ) - ) - - updated = False - for index in range(0, len(self.in_response_to)): - if response.text == self.in_response_to[index].text: - self.in_response_to[index].occurrence += 1 - updated = True - - if not updated: - self.in_response_to.append(response) - - def remove_response(self, response_text): - """ - Removes a response from the statement's response list based - on the value of the response text. - - :param response_text: The text of the response to be removed. - :type response_text: str - """ - for response in self.in_response_to: - if response_text == response.text: - self.in_response_to.remove(response) - return True - return False - - def get_response_count(self, statement): - """ - Find the number of times that the statement has been used - as a response to the current statement. - - :param statement: The statement object to get the count for. - :type statement: `Statement` - - :returns: Return the number of times the statement has been used as a response. - :rtype: int - """ - for response in self.in_response_to: - if statement.text == response.text: - return response.occurrence - - return 0 - def serialize(self): - """ - :returns: A dictionary representation of the statement object. - :rtype: dict - """ - data = {} - - data['text'] = self.text - data['in_response_to'] = [] - data['created_at'] = self.created_at - data['extra_data'] = self.extra_data - - if self.conversation_id: - data['conversation_id'] = self.conversation_id - - for response in self.in_response_to: - data['in_response_to'].append(response.serialize()) - - return data + serializer = StatementSerializer() + return serializer.serialize(self) @property def response_statement_cache(self): diff --git a/chatterbot/ext/django_chatterbot/admin.py b/chatterbot/ext/django_chatterbot/admin.py index 5d337eb8c..12b614739 100644 --- a/chatterbot/ext/django_chatterbot/admin.py +++ b/chatterbot/ext/django_chatterbot/admin.py @@ -1,5 +1,5 @@ from django.contrib import admin -from chatterbot.ext.django_chatterbot.models import Statement, Response, Conversation +from chatterbot.ext.django_chatterbot.models import Statement, Conversation class StatementAdmin(admin.ModelAdmin): @@ -8,15 +8,9 @@ class StatementAdmin(admin.ModelAdmin): search_fields = ('text', ) -class ResponseAdmin(admin.ModelAdmin): - list_display = ('statement', 'occurrence', ) - - class ConversationAdmin(admin.ModelAdmin): - list_display = ('statement', 'occurrence', ) list_display = ('root', ) admin.site.register(Statement, StatementAdmin) -admin.site.register(Response, ResponseAdmin) admin.site.register(Conversation, ConversationAdmin) diff --git a/chatterbot/ext/django_chatterbot/models.py b/chatterbot/ext/django_chatterbot/models.py index fce2e3b98..bd87a2f2b 100644 --- a/chatterbot/ext/django_chatterbot/models.py +++ b/chatterbot/ext/django_chatterbot/models.py @@ -9,19 +9,11 @@ class Statement(models.Model): """ text = models.CharField( - unique=True, blank=False, null=False, max_length=255 ) - conversation = models.ForeignKey( - 'Conversation', - related_name='statements', - blank=True, - null=True - ) - created_at = models.DateTimeField( default=timezone.now, help_text='The date and time that this statement was created at.' @@ -29,6 +21,12 @@ class Statement(models.Model): extra_data = models.CharField(max_length=500) + response = models.OneToOneField( + 'Statement', + related_name='in_response_to', + null=True + ) + # This is the confidence with which the chat bot believes # this is an accurate response. This value is set when the # statement is returned by the chat bot. @@ -47,12 +45,11 @@ def __init__(self, *args, **kwargs): # Responses to be saved if the statement is updated with the storage adapter self.response_statement_cache = [] - @property - def in_response_to(self): + def responses(self): """ - Return the response objects that are for this statement. + Return a list of statements that are known responses to this statement. """ - return Response.objects.filter(statement=self) + return Statement.objects.filter(in_response_to__text=self.text) def add_extra_data(self, key, value): """ @@ -68,45 +65,6 @@ def add_extra_data(self, key, value): self.extra_data = json.dumps(extra_data) - def add_response(self, statement): - """ - Add a response to this statement. - """ - self.response_statement_cache.append(statement) - - def remove_response(self, response_text): - """ - Removes a response from the statement's response list based - on the value of the response text. - - :param response_text: The text of the response to be removed. - :type response_text: str - """ - is_deleted = False - response = self.in_response.filter(response__text=response_text) - - if response.exists(): - is_deleted = True - - return is_deleted - - def get_response_count(self, statement): - """ - Find the number of times that the statement has been used - as a response to the current statement. - - :param statement: The statement object to get the count for. - :type statement: chatterbot.conversation.statement.Statement - - :returns: Return the number of times the statement has been used as a response. - :rtype: int - """ - try: - response = self.in_response.get(response__text=statement.text) - return response.occurrence - except Response.DoesNotExist: - return 0 - def serialize(self): """ :returns: A dictionary representation of the statement object. @@ -119,56 +77,10 @@ def serialize(self): self.extra_data = '{}' data['text'] = self.text - data['in_response_to'] = [] + data['in_response_to'] = {'text': self.in_response_to.text} data['created_at'] = self.created_at data['extra_data'] = json.loads(self.extra_data) - for response in self.in_response.all(): - data['in_response_to'].append(response.serialize()) - - return data - - -class Response(models.Model): - """ - Connection between a response and the statement that triggered it. - - Comparble to a ManyToMany "through" table, but without the M2M indexing/relations. - The text and number of times the response has occurred are stored. - """ - - statement = models.ForeignKey( - 'Statement', - related_name='in_response' - ) - - response = models.ForeignKey( - 'Statement', - related_name='responses' - ) - - unique_together = (('statement', 'response'),) - - occurrence = models.PositiveIntegerField(default=1) - - def __str__(self): - statement = self.statement.text - response = self.response.text - return '{} => {}'.format( - statement if len(statement) <= 20 else statement[:17] + '...', - response if len(response) <= 40 else response[:37] + '...' - ) - - def serialize(self): - """ - :returns: A dictionary representation of the statement object. - :rtype: dict - """ - data = {} - - data['text'] = self.response.text - data['occurrence'] = self.occurrence - return data diff --git a/chatterbot/ext/django_chatterbot/views.py b/chatterbot/ext/django_chatterbot/views.py index 7f3ab81ba..bb044611c 100644 --- a/chatterbot/ext/django_chatterbot/views.py +++ b/chatterbot/ext/django_chatterbot/views.py @@ -44,13 +44,10 @@ class ChatterBotView(ChatterBotViewMixin, View): """ def _serialize_conversation(self, session): - if session.conversation.empty(): - return [] - conversation = [] - for statement, response in session.conversation: - conversation.append([statement.serialize(), response.serialize()]) + for statement in session.statements: + conversation.append(statement.serialize()) return conversation diff --git a/chatterbot/filters.py b/chatterbot/filters.py index 1b45574d6..ae41fcbea 100644 --- a/chatterbot/filters.py +++ b/chatterbot/filters.py @@ -28,13 +28,13 @@ def filter_selection(self, chatterbot, session_id): session = chatterbot.conversation_sessions.get(session_id) - if session.conversation.empty(): + if not session.statements.exists(): return chatterbot.storage.base_query text_of_recent_responses = [] - for statement, response in session.conversation: - text_of_recent_responses.append(response.text) + for statement in session.statements: + text_of_recent_responses.append(statement.text) query = chatterbot.storage.base_query.statement_text_not_in( text_of_recent_responses diff --git a/chatterbot/storage/django_storage.py b/chatterbot/storage/django_storage.py index 4c1dd7005..d6a6df41a 100644 --- a/chatterbot/storage/django_storage.py +++ b/chatterbot/storage/django_storage.py @@ -109,6 +109,7 @@ def get_random(self): Returns a random statement from the database """ from chatterbot.ext.django_chatterbot.models import Statement + return Statement.objects.order_by('?').first() def remove(self, statement_text): @@ -118,25 +119,15 @@ def remove(self, statement_text): input text. """ from chatterbot.ext.django_chatterbot.models import Statement - from chatterbot.ext.django_chatterbot.models import Response - from django.db.models import Q - statements = Statement.objects.filter(text=statement_text) - - responses = Response.objects.filter( - Q(statement__text=statement_text) | Q(response__text=statement_text) - ) - - responses.delete() - statements.delete() + Statement.objects.filter(text=statement_text).delete() def drop(self): """ Remove all data from the database. """ - from chatterbot.ext.django_chatterbot.models import Statement, Response, Conversation + from chatterbot.ext.django_chatterbot.models import Statement, Conversation Statement.objects.all().delete() - Response.objects.all().delete() Conversation.objects.all().delete() def get_response_statements(self): @@ -146,8 +137,6 @@ def get_response_statements(self): in_response_to field. Otherwise, the logic adapter may find a closest matching statement that does not have a known response. """ - from chatterbot.ext.django_chatterbot.models import Statement, Response - - responses = Response.objects.all() + from chatterbot.ext.django_chatterbot.models import Statement - return Statement.objects.filter(in_response__in=responses) + return Statement.objects.filter(in_response_to__isnull=False) diff --git a/chatterbot/storage/jsonfile.py b/chatterbot/storage/jsonfile.py index c81ef8941..702ea9745 100644 --- a/chatterbot/storage/jsonfile.py +++ b/chatterbot/storage/jsonfile.py @@ -1,6 +1,5 @@ import warnings from chatterbot.storage import StorageAdapter -from chatterbot.conversation import Response class JsonFileStorageAdapter(StorageAdapter): @@ -29,14 +28,13 @@ def __init__(self, **kwargs): database_path = self.kwargs.get('database', 'database.db') self.database = Database(database_path) - self.adapter_supports_queries = False + # Create the statements document as an empty list + self.database['statements'] = [] - def _keys(self): - # The value has to be cast as a list for Python 3 compatibility - return list(self.database[0].keys()) + self.adapter_supports_queries = False def count(self): - return len(self._keys()) + return len(self.database['statements'].keys()) def find(self, statement_text): values = self.database.data(key=statement_text) @@ -60,24 +58,6 @@ def remove(self, statement_text): self.database.delete(statement_text) - def deserialize_responses(self, response_list): - """ - Takes the list of response items and returns - the list converted to Response objects. - """ - proxy_statement = self.Statement('') - - for response in response_list: - data = response.copy() - text = data['text'] - del data['text'] - - proxy_statement.add_response( - Response(text, **data) - ) - - return proxy_statement.in_response_to - def json_to_object(self, statement_data): """ Converts a dictionary-like object to a Statement object. @@ -87,8 +67,8 @@ def json_to_object(self, statement_data): statement_data = statement_data.copy() # Build the objects for the response list - statement_data['in_response_to'] = self.deserialize_responses( - statement_data['in_response_to'] + statement_data['in_response_to'] = self.Statement( + **statement_data['in_response_to'] ) # Remove the text attribute from the values @@ -131,14 +111,10 @@ def filter(self, **kwargs): order_by = kwargs.pop('order_by', None) - for key in self._keys(): - values = self.database.data(key=key) + for statement in self.database['statements']: - # Add the text attribute to the values - values['text'] = key - - if self._all_kwargs_match_values(kwargs, values): - results.append(self.json_to_object(values)) + if self._all_kwargs_match_values(kwargs, statement): + results.append(self.json_to_object(statement)) if order_by: @@ -154,18 +130,10 @@ def update(self, statement): """ Update a statement in the database. """ - data = statement.serialize() - - # Remove the text key from the data - del data['text'] - self.database.data(key=statement.text, value=data) + statements = self.database['statements'] + statements.append(statement.serialize()) - # Make sure that an entry for each response exists - for response_statement in statement.in_response_to: - response = self.find(response_statement.text) - if not response: - response = self.Statement(response_statement.text) - self.update(response) + self.database.data(key='statements', value=statements) return statement @@ -175,8 +143,7 @@ def get_random(self): if self.count() < 1: raise self.EmptyDatabaseException() - statement = choice(self._keys()) - return self.find(statement) + return choice(self.database['statements']) def drop(self): """ diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index e025c9c3e..0cc4bb1c2 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -1,5 +1,4 @@ from chatterbot.storage import StorageAdapter -from chatterbot.conversation import Response class Query(object): @@ -98,9 +97,6 @@ def __init__(self, **kwargs): # The mongo collection of statement documents self.statements = self.database['statements'] - # Set a requirement for the text attribute to be unique - self.statements.create_index('text', unique=True) - self.base_query = Query() def count(self): @@ -134,10 +130,6 @@ def deserialize_responses(self, response_list): text = response['text'] del response['text'] - proxy_statement.add_response( - Response(text, **response) - ) - return proxy_statement.in_response_to def mongo_to_object(self, statement_data): diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index 755b68b29..decd49a96 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -40,19 +40,6 @@ def generate_base_query(self, chatterbot, session_id): for filter_instance in chatterbot.filters: self.base_query = filter_instance.filter_selection(chatterbot, session_id) - def create_conversation(self): - """ - Returns a new storage-aware conversation instance. - """ - import os - - if 'DJANGO_SETTINGS_MODULE' in os.environ: - from chatterbot.ext.django_chatterbot.models import Conversation - return Conversation.objects.create() - else: - from chatterbot.conversation import Conversation - return Conversation(self) - def count(self): """ Return the number of entries in the database. diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 6946c437a..515719e1f 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -1,5 +1,5 @@ import logging -from .conversation import Statement, Response +from .conversation import Statement class Trainer(object): @@ -82,9 +82,7 @@ def train(self, conversation): statement = self.get_or_create(text) if statement_history: - statement.add_response( - Response(statement_history[-1].text) - ) + statement.in_response_to = statement_history[-1] statement_history.append(statement) self.storage.update(statement) @@ -197,7 +195,7 @@ def get_statements(self): if tweet.in_reply_to_status_id: try: status = self.api.GetStatus(tweet.in_reply_to_status_id) - statement.add_response(Response(status.text)) + statement.in_response_to = Statement(status.text) statements.append(statement) except TwitterError as error: self.logger.warning(str(error)) @@ -346,9 +344,7 @@ def train(self): statement.add_extra_data('addressing_speaker', row[2]) if statement_history: - statement.add_response( - Response(statement_history[-1].text) - ) + statement.in_response_to = statement_history[-1] statement_history.append(statement) self.storage.update(statement) diff --git a/tests/conversation_tests/test_responses.py b/tests/conversation_tests/test_responses.py deleted file mode 100644 index f1a71cf81..000000000 --- a/tests/conversation_tests/test_responses.py +++ /dev/null @@ -1,9 +0,0 @@ -from unittest import TestCase -from chatterbot.conversation import Response - - -class ResponseTests(TestCase): - - def setUp(self): - self.response = Response("A test response.") - diff --git a/tests/conversation_tests/test_statements.py b/tests/conversation_tests/test_statements.py index 8192c45fd..986ce7754 100644 --- a/tests/conversation_tests/test_statements.py +++ b/tests/conversation_tests/test_statements.py @@ -1,22 +1,18 @@ # -*- coding: utf-8 -*- from unittest import TestCase -from chatterbot.conversation import Statement, Response +from chatterbot.conversation import Statement class StatementTests(TestCase): - def setUp(self): - self.statement = Statement("A test statement.") - def test_list_equality(self): """ It should be possible to check if a statement exists in the list of statements that another statement has been issued in response to. """ - self.statement.add_response(Response("Yo")) - self.assertEqual(len(self.statement.in_response_to), 1) - self.assertIn(Response("Yo"), self.statement.in_response_to) + statements = [Statement('Hi'), Statement('Hello')] + self.assertEqual(Statement('Hi'), statements) def test_list_equality_unicode(self): """ @@ -24,62 +20,5 @@ def test_list_equality_unicode(self): is in a list of other statements when the statements text is unicode. """ - statements = [Statement("Hello"), Statement("我很好太感谢")] - statement = Statement("我很好太感谢") - self.assertIn(statement, statements) - - def test_update_response_list_new(self): - self.statement.add_response(Response("Hello")) - self.assertTrue(len(self.statement.in_response_to), 1) - - def test_update_response_list_existing(self): - response = Response("Hello") - self.statement.add_response(response) - self.statement.add_response(response) - self.assertTrue(len(self.statement.in_response_to), 1) - - def test_remove_response_exists(self): - self.statement.add_response(Response("Testing")) - removed = self.statement.remove_response("Testing") - self.assertTrue(removed) - - def test_remove_response_does_not_exist(self): - self.statement.add_response(Response("Testing")) - removed = self.statement.remove_response("Test") - self.assertFalse(removed) - - def test_serializer(self): - data = self.statement.serialize() - self.assertEqual(self.statement.text, data["text"]) - - def test_occurrence_count_for_new_statement(self): - """ - When the occurrence is updated for a statement that - previously did not exist as a statement that the current - statement was issued in response to, then the new statement - should be added to the response list and the occurence count - for that response should be set to 1. - """ - response = Response("This is a test.") - - self.statement.add_response(response) - self.assertTrue(self.statement.get_response_count(response), 1) - - def test_occurrence_count_for_existing_statement(self): - self.statement.add_response(Response("ABC")) - self.statement.add_response(Response("ABC")) - self.assertTrue( - self.statement.get_response_count(Response("ABC")), - 2 - ) - - def test_occurrence_count_incremented(self): - self.statement.add_response(Response("ABC")) - self.statement.add_response(Response("ABC")) - - self.assertEqual(len(self.statement.in_response_to), 1) - self.assertEqual(self.statement.in_response_to[0].occurrence, 2) - - def test_add_non_response(self): - with self.assertRaises(Statement.InvalidTypeException): - self.statement.add_response(Statement("Blah")) + statements = [Statement('Hello'), Statement('我很好太感谢')] + self.assertIn(Statement('我很好太感谢'), statements) diff --git a/tests/storage_adapter_tests/test_json_file_storage_adapter.py b/tests/storage_adapter_tests/test_json_file_storage_adapter.py index 26b0ffce1..264c30798 100644 --- a/tests/storage_adapter_tests/test_json_file_storage_adapter.py +++ b/tests/storage_adapter_tests/test_json_file_storage_adapter.py @@ -95,16 +95,12 @@ def test_update_modifies_existing_statement(self): ) # Update the statement value - statement.add_response( - Response("New response") - ) + statement.in_response_to = Statement('New response') self.adapter.update(statement) # Check that the values have changed found_statement = self.adapter.find(statement.text) - self.assertEqual( - len(found_statement.in_response_to), 1 - ) + self.assertEqual(found_statement.in_response_to, statement.in_response_to) def test_get_random_returns_statement(self): statement = Statement("New statement") @@ -158,17 +154,20 @@ def test_update_saves_statement_with_multiple_responses(self): self.assertEqual(len(response.in_response_to), 2) def test_getting_and_updating_statement(self): - statement = Statement("Hi") + statement1 = Statement('Hi', in_response_to=Statement('Hello')) + statement2 = Statement('Hi', in_response_to=Statement('Hello')) self.adapter.update(statement) - statement.add_response(Response("Hello")) - statement.add_response(Response("Hello")) - self.adapter.update(statement) + self.adapter.update(statement1) + self.adapter.update(statement2) - response = self.adapter.find(statement.text) + results = self.adapter.filter(statement.text) - self.assertEqual(len(response.in_response_to), 1) - self.assertEqual(response.in_response_to[0].occurrence, 2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0].text = statement1.text) + self.assertEqual(results[1].text = statement2.text) + self.assertEqual(results[0].in_response_to = statement1.in_response_to) + self.assertEqual(results[1].in_response_to = statement2.in_response_to) def test_deserialize_responses(self): response_list = [ diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 88bb886ff..35bf48af7 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -46,7 +46,7 @@ def test_statement_added_to_recent_response_list(self): self.chatbot.default_session.id_string ) - self.assertIn(statement_text, session.conversation[0]) + self.assertIn(statement_text, session.statements.all()) self.assertEqual(response, statement_text) def test_response_known(self): diff --git a/tests/test_context.py b/tests/test_context.py index 7fa6f411e..5b1bd0fb3 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,4 +1,5 @@ from .base_case import ChatBotTestCase +from chatterbot.conversation import Statement class AdapterTests(ChatBotTestCase): @@ -8,14 +9,15 @@ def test_modify_chatbot(self): When one adapter modifies its chatbot instance, the change should be the same in all other adapters. """ + session = self.chatbot.input.chatbot.conversation_sessions.new() self.chatbot.input.chatbot.conversation_sessions.update( session.id_string, - ('A', 'B', ) - ) + Statement('A'), Statement('B') session = self.chatbot.output.chatbot.conversation_sessions.get( session.id_string ) - self.assertIn(('A', 'B', ), session.conversation) + self.assertEqual(Statement('A'), session.statements.all()[0]) + self.assertEqual(Statement('B'), session.statements.all()[1]) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index fd3db2a36..5dcccf09b 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -1,4 +1,4 @@ -from chatterbot.conversation import Conversation +from chatterbot.conversation import Statement, Conversation from chatterbot.conversation.session import ConversationSessionManager from .base_case import ChatBotTestCase @@ -41,11 +41,13 @@ def test_get_invalid_id_with_deafult(self): def test_update(self): session = self.manager.new() - self.manager.update(session.id_string, ('A', 'B', )) + self.manager.update(session.id_string, (Statement('A'), Statement('B'), )) session_ids = list(self.manager.sessions.keys()) session_id = session_ids[0] self.assertEqual(len(session_ids), 1) - self.assertEqual(len(self.manager.get(session_id).conversation), 1) - self.assertEqual(('A', 'B', ), self.manager.get(session_id).conversation[0]) + self.assertEqual(self.manager.get(session_id).statements.count(), 2) + self.assertEqual(Statement('A'), self.manager.get(session_id).statements.all()[0]) + self.assertEqual(Statement('B'), self.manager.get(session_id).statements.all()[1]) + diff --git a/tests_django/integration_tests/test_statement_integration.py b/tests_django/integration_tests/test_statement_integration.py index 8b29d5a4c..693291e64 100644 --- a/tests_django/integration_tests/test_statement_integration.py +++ b/tests_django/integration_tests/test_statement_integration.py @@ -1,9 +1,7 @@ from django.test import TestCase from django.utils import timezone from chatterbot.conversation import Statement as StatementObject -from chatterbot.conversation import Response as ResponseObject from chatterbot.ext.django_chatterbot.models import Statement as StatementModel -from chatterbot.ext.django_chatterbot.models import Response as ResponseModel class StatementIntegrationTestCase(TestCase): @@ -40,36 +38,6 @@ def test_add_extra_data(self): self.object.add_extra_data('key', 'value') self.model.add_extra_data('key', 'value') - def test_add_response(self): - self.assertTrue(hasattr(self.object, 'add_response')) - self.assertTrue(hasattr(self.model, 'add_response')) - - def test_remove_response(self): - self.object.add_response(ResponseObject('Hello')) - model_response_statement = StatementModel.objects.create(text='Hello') - self.model.save() - self.model.in_response.create(statement=self.model, response=model_response_statement) - - object_removed = self.object.remove_response('Hello') - model_removed = self.model.remove_response('Hello') - - self.assertTrue(object_removed) - self.assertTrue(model_removed) - - def test_get_response_count(self): - self.object.add_response(ResponseObject('Hello', occurrence=2)) - model_response_statement = StatementModel.objects.create(text='Hello') - self.model.save() - self.model.in_response.create( - statement=self.model, response=model_response_statement, occurrence=2 - ) - - object_count = self.object.get_response_count(StatementObject(text='Hello')) - model_count = self.model.get_response_count(StatementModel(text='Hello')) - - self.assertEqual(object_count, 2) - self.assertEqual(model_count, 2) - def test_serialize(self): object_data = self.object.serialize() model_data = self.model.serialize() @@ -83,25 +51,3 @@ def test_serialize(self): def test_response_statement_cache(self): self.assertTrue(hasattr(self.object, 'response_statement_cache')) self.assertTrue(hasattr(self.model, 'response_statement_cache')) - - -class ResponseIntegrationTestCase(TestCase): - - """ - Test case to make sure that the Django Response model - and ChatterBot Response object have a common interface. - """ - - def setUp(self): - super(ResponseIntegrationTestCase, self).setUp() - date_created = timezone.now() - statement_object = StatementObject(text='_', created_at=date_created) - statement_model = StatementModel.objects.create(text='_', created_at=date_created) - self.object = ResponseObject(statement_object.text) - self.model = ResponseModel(statement=statement_model, response=statement_model) - - def test_serialize(self): - object_data = self.object.serialize() - model_data = self.model.serialize() - - self.assertEqual(object_data, model_data) diff --git a/tests_django/test_django_adapter.py b/tests_django/test_django_adapter.py index 496b9592f..187872932 100644 --- a/tests_django/test_django_adapter.py +++ b/tests_django/test_django_adapter.py @@ -2,7 +2,6 @@ from django.test import TestCase from chatterbot.storage import DjangoStorageAdapter from chatterbot.ext.django_chatterbot.models import Statement as StatementModel -from chatterbot.ext.django_chatterbot.models import Response as ResponseModel class DjangoAdapterTestCase(TestCase): @@ -89,16 +88,21 @@ def test_get_random_returns_statement(self): self.assertEqual(random_statement.text, statement.text) def test_find_returns_nested_responses(self): - statement = StatementModel.objects.create(text="Do you like this?") - statement.add_response(StatementModel(text="Yes")) - statement.add_response(StatementModel(text="No")) + question = StatementModel.objects.create(text='Do you like this?') - self.adapter.update(statement) + yes = StatementModel(text='Yes') + yes.in_response_to = question + yes.save() - result = self.adapter.find(statement.text) + no = StatementModel(text='No') + no.in_response_to = question + no.save() + + result = self.adapter.find(question.text) + responses = result.responses() - self.assertTrue(result.in_response_to.filter(response__text="Yes").exists()) - self.assertTrue(result.in_response_to.filter(response__text="No").exists()) + self.assertTrue(responses.in_response_to.filter(text='Yes').exists()) + self.assertTrue(responses.in_response_to.filter(text='No').exists()) def test_multiple_responses_added_on_update(self): statement = StatementModel.objects.create(text="You are welcome.") @@ -305,7 +309,7 @@ def test_response_list_in_results(self): found = self.adapter.filter(text=statement.text) self.assertEqual(len(found[0].in_response_to), 1) - self.assertEqual(type(found[0].in_response_to[0]), ResponseModel) + self.assertEqual(type(found[0].in_response_to[0]), StatementModel) def test_confidence(self): """