diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index bc269ea76..e95cf842c 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -12,7 +12,7 @@ class ChatBot(object): """ def __init__(self, name, **kwargs): - from .conversation.session import ConversationSessionManager + from .conversation.session import ConversationManager from .logic import MultiLogicAdapter self.name = name @@ -71,8 +71,8 @@ def __init__(self, name, **kwargs): self.trainer = TrainerClass(self.storage, **kwargs) self.training_data = kwargs.get('training_data') - self.conversation_sessions = ConversationSessionManager() - self.default_session = self.conversation_sessions.new() + self.conversation_sessions = ConversationManager(self.storage) + self.default_session = None self.logger = kwargs.get('logger', logging.getLogger(__name__)) @@ -102,27 +102,33 @@ def get_response(self, input_item, session_id=None): :returns: A response to the input. :rtype: Statement """ - if not session_id: - session_id = str(self.default_session.uuid) - input_statement = self.input.process_input_statement(input_item) # Preprocess the input statement for preprocessor in self.preprocessors: input_statement = preprocessor(self, input_statement) - statement, response = self.generate_response(input_statement, session_id) + if session_id: + session = self.conversation_sessions.get(session_id) + + if not session: + session = self.get_or_create_default_conversation() + else: + session = self.get_or_create_default_conversation() + + statement, response = self.generate_response(input_statement, session.id) # Learn that the user's input was a valid response to the chat bot's previous output - previous_statement = self.conversation_sessions.get( - session_id - ).conversation.get_last_response_statement() - self.learn_response(statement, previous_statement) + previous_statement = session.get_last_response_statement() + + self.learn_response(statement, previous_statement, session) - self.conversation_sessions.update(session_id, (statement, response, )) + if not self.read_only: + response.save() + session.statements.add(response) # Process the response output with the output adapter - return self.output.process_response(response, session_id) + return self.output.process_response(response, session.id) def generate_response(self, input_statement, session_id): """ @@ -135,16 +141,15 @@ def generate_response(self, input_statement, session_id): return input_statement, response - def learn_response(self, statement, previous_statement): + def learn_response(self, statement, previous_statement, session=None): """ Learn that the statement provided is a valid response. """ - from .conversation import Response + if not session: + session = self.get_or_create_default_conversation() if previous_statement: - statement.add_response( - Response(previous_statement.text) - ) + statement.in_response_to = previous_statement self.logger.info('Adding "{}" as a response to "{}"'.format( statement.text, previous_statement.text @@ -153,6 +158,7 @@ def learn_response(self, statement, previous_statement): # Save the statement after selecting a response if not self.read_only: self.storage.update(statement) + session.statements.add(statement) def set_trainer(self, training_class, **kwargs): """ @@ -165,6 +171,17 @@ def set_trainer(self, training_class, **kwargs): """ self.trainer = training_class(self.storage, **kwargs) + def get_or_create_default_conversation(self): + """ + Get the default conversation session if it exists. + Otherwise create a new conversation. + This is a lazy function designed to only create the conversation if + a statement exists for it. + """ + if not self.default_session: + self.default_session = self.storage.Conversation.objects.create() + return self.default_session + @property def train(self): """ diff --git a/chatterbot/conversation/__init__.py b/chatterbot/conversation/__init__.py index 71af608cb..dbad82ba7 100644 --- a/chatterbot/conversation/__init__.py +++ b/chatterbot/conversation/__init__.py @@ -1,2 +1,2 @@ from .statement import Statement -from .response import Response +from .session import Conversation diff --git a/chatterbot/conversation/response.py b/chatterbot/conversation/response.py deleted file mode 100644 index a1aff3fd8..000000000 --- a/chatterbot/conversation/response.py +++ /dev/null @@ -1,34 +0,0 @@ -class Response(object): - """ - A response represents an entity which response to a statement. - """ - - def __init__(self, text, **kwargs): - self.text = text - self.occurrence = kwargs.get('occurrence', 1) - - def __str__(self): - return self.text - - def __repr__(self): - return '' % (self.text) - - def __hash__(self): - return hash(self.text) - - def __eq__(self, other): - if not other: - return False - - if isinstance(other, Response): - return self.text == other.text - - return self.text == other - - def serialize(self): - data = {} - - data['text'] = self.text - data['occurrence'] = self.occurrence - - return data diff --git a/chatterbot/conversation/session.py b/chatterbot/conversation/session.py index 56a47bf83..70ea512e8 100644 --- a/chatterbot/conversation/session.py +++ b/chatterbot/conversation/session.py @@ -1,50 +1,112 @@ import uuid -from chatterbot.queues import ResponseQueue -class Session(object): +class ConversationModelMixin(object): + + collection_name = 'conversations' + + pk_field = 'id' + fields = ('id', ) + + def get_last_response_statement(self): + """ + Return the last statement that was received. + """ + if self.statements.exists(): + # Return the output statement + return self.statements.latest('id') + return None + + def get_last_input_statement(self): + """ + Return the last response that was given. + """ + if self.statements.count() > 1: + # Return the input statement + return self.statements.all()[-2] + return None + + def serialize(self): + statements = [] + + for statement in self.statements.all(): + statements.append({'text': statement.text}) + + return { + 'id': self.id, + 'statements': statements + } + + +class StatementRelatedManager(object): + + def __init__(self, conversation, statements): + self.statements = statements + self.conversation = conversation + + def exists(self): + return len(self.statements) > 0 + + def count(self): + return len(self.statements) + + def first(self): + return self.statements[0] + + def latest(self, *args): + return self.statements[-1] + + def all(self): + return self.statements + + def add(self, statement): + self.statements.append(statement) + self.conversation.save() + + + +class Conversation(ConversationModelMixin): """ A single chat session. """ - def __init__(self): + objects = None + + def __init__(self, **kwargs): # A unique identifier for the chat session self.uuid = uuid.uuid1() - self.id_string = str(self.uuid) - self.id = str(self.uuid) + self.id = kwargs.get('id', str(self.uuid)) - # The last 10 statement inputs and outputs - self.conversation = ResponseQueue(maxsize=10) + statements = kwargs.get('statements', []) + self.statements = StatementRelatedManager(self, statements) + def save(self): + self.objects.storage.update(self) -class ConversationSessionManager(object): + +class ConversationManager(object): """ Object to hold and manage multiple chat sessions. """ - def __init__(self): - self.sessions = {} + def __init__(self, storage): + self.storage = storage - def new(self): + def create(self): """ Add a new chat session. """ - session = Session() - - self.sessions[session.id_string] = session - - return session + conversation = self.storage.Conversation() + conversation.save() + return conversation def get(self, session_id, default=None): """ Return a session given a unique identifier. """ - return self.sessions.get(str(session_id), default) + results = self.storage.filter(self.storage.Conversation, id=session_id) + if results: + return results[0] + else: + return default - def update(self, session_id, conversance): - """ - Add a conversance to a given session if the session exists. - """ - session_id = str(session_id) - if session_id in self.sessions: - self.sessions[session_id].conversation.append(conversance) diff --git a/chatterbot/conversation/statement.py b/chatterbot/conversation/statement.py index a8ec0cde4..b2fa3bf0f 100644 --- a/chatterbot/conversation/statement.py +++ b/chatterbot/conversation/statement.py @@ -1,17 +1,49 @@ # -*- coding: utf-8 -*- -from .response import Response from datetime import datetime -class Statement(object): +class StatementModelMixin(object): + + collection_name = 'statements' + + pk_field = 'text' + fields = ( + 'text', 'created_at', 'in_response_to', 'extra_data', + ) + + +class StatementSerializer(object): + + def serialize(self, obj): + """ + :returns: A dictionary representation of the statement object. + :rtype: dict + """ + data = {} + + data['text'] = obj.text + if obj.in_response_to: + data['in_response_to'] = {'text': obj.in_response_to.text} + data['created_at'] = obj.created_at + data['extra_data'] = obj.extra_data + + return data + + def deserialize(self, data): + pass + + +class Statement(StatementModelMixin): """ A statement represents a single spoken entity, sentence or phrase that someone can say. """ + storage = None + def __init__(self, text, **kwargs): self.text = text - self.in_response_to = kwargs.pop('in_response_to', []) + self.in_response_to = kwargs.pop('in_response_to', None) # The date and time that this statement was created at self.created_at = kwargs.pop('created_at', datetime.now()) @@ -23,8 +55,6 @@ def __init__(self, text, **kwargs): # statement is returned by the chat bot. self.confidence = 0 - self.storage = None - def __str__(self): return self.text @@ -67,86 +97,9 @@ def add_extra_data(self, key, value): """ self.extra_data[key] = value - def add_response(self, response): - """ - Add the response to the list of statements that this statement is in response to. - If the response is already in the list, increment the occurrence count of that response. - - :param response: The response to add. - :type response: `Response` - """ - if not isinstance(response, Response): - raise Statement.InvalidTypeException( - 'A {} was recieved when a {} instance was expected'.format( - type(response), - type(Response('')) - ) - ) - - updated = False - for index in range(0, len(self.in_response_to)): - if response.text == self.in_response_to[index].text: - self.in_response_to[index].occurrence += 1 - updated = True - - if not updated: - self.in_response_to.append(response) - - def remove_response(self, response_text): - """ - Removes a response from the statement's response list based - on the value of the response text. - - :param response_text: The text of the response to be removed. - :type response_text: str - """ - for response in self.in_response_to: - if response_text == response.text: - self.in_response_to.remove(response) - return True - return False - - def get_response_count(self, statement): - """ - Find the number of times that the statement has been used - as a response to the current statement. - - :param statement: The statement object to get the count for. - :type statement: `Statement` - - :returns: Return the number of times the statement has been used as a response. - :rtype: int - """ - for response in self.in_response_to: - if statement.text == response.text: - return response.occurrence - - return 0 - def serialize(self): - """ - :returns: A dictionary representation of the statement object. - :rtype: dict - """ - data = {} - - data['text'] = self.text - data['in_response_to'] = [] - data['created_at'] = self.created_at - data['extra_data'] = self.extra_data - - for response in self.in_response_to: - data['in_response_to'].append(response.serialize()) - - return data - - @property - def response_statement_cache(self): - """ - This property is to allow ChatterBot Statement objects to - be swappable with Django Statement models. - """ - return self.in_response_to + serializer = StatementSerializer() + return serializer.serialize(self) class InvalidTypeException(Exception): diff --git a/chatterbot/ext/django_chatterbot/admin.py b/chatterbot/ext/django_chatterbot/admin.py index 5d337eb8c..75be6a8af 100644 --- a/chatterbot/ext/django_chatterbot/admin.py +++ b/chatterbot/ext/django_chatterbot/admin.py @@ -1,5 +1,5 @@ from django.contrib import admin -from chatterbot.ext.django_chatterbot.models import Statement, Response, Conversation +from chatterbot.ext.django_chatterbot.models import Statement, Conversation class StatementAdmin(admin.ModelAdmin): @@ -8,15 +8,9 @@ class StatementAdmin(admin.ModelAdmin): search_fields = ('text', ) -class ResponseAdmin(admin.ModelAdmin): - list_display = ('statement', 'occurrence', ) - - class ConversationAdmin(admin.ModelAdmin): - list_display = ('statement', 'occurrence', ) - list_display = ('root', ) + pass admin.site.register(Statement, StatementAdmin) -admin.site.register(Response, ResponseAdmin) admin.site.register(Conversation, ConversationAdmin) diff --git a/chatterbot/ext/django_chatterbot/migrations/0007_remove_root_related_name.py b/chatterbot/ext/django_chatterbot/migrations/0007_remove_root_related_name.py new file mode 100644 index 000000000..acbd1dc37 --- /dev/null +++ b/chatterbot/ext/django_chatterbot/migrations/0007_remove_root_related_name.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11a1 on 2017-01-22 12:37 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_chatterbot', '0006_create_conversation'), + ] + + operations = [ + migrations.AlterField( + model_name='conversation', + name='root', + field=models.OneToOneField(help_text=b'The initiating statement in a conversation.', on_delete=django.db.models.deletion.CASCADE, related_name='+', to='django_chatterbot.Statement'), + ), + ] diff --git a/chatterbot/ext/django_chatterbot/migrations/0008_statement_conversation.py b/chatterbot/ext/django_chatterbot/migrations/0008_statement_conversation.py new file mode 100644 index 000000000..8f7ca3992 --- /dev/null +++ b/chatterbot/ext/django_chatterbot/migrations/0008_statement_conversation.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11a1 on 2017-01-22 12:40 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_chatterbot', '0007_remove_root_related_name'), + ] + + operations = [ + migrations.AddField( + model_name='statement', + name='conversation', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='conversation', to='django_chatterbot.Conversation'), + ), + ] diff --git a/chatterbot/ext/django_chatterbot/migrations/0009_conversation_root_null.py b/chatterbot/ext/django_chatterbot/migrations/0009_conversation_root_null.py new file mode 100644 index 000000000..684200ed2 --- /dev/null +++ b/chatterbot/ext/django_chatterbot/migrations/0009_conversation_root_null.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11a1 on 2017-01-22 13:35 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_chatterbot', '0008_statement_conversation'), + ] + + operations = [ + migrations.AlterField( + model_name='conversation', + name='root', + field=models.OneToOneField(help_text=b'The initiating statement in a conversation.', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='django_chatterbot.Statement'), + ), + ] diff --git a/chatterbot/ext/django_chatterbot/migrations/0010_conversation_statements.py b/chatterbot/ext/django_chatterbot/migrations/0010_conversation_statements.py new file mode 100644 index 000000000..96d12587f --- /dev/null +++ b/chatterbot/ext/django_chatterbot/migrations/0010_conversation_statements.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11a1 on 2017-01-22 14:50 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_chatterbot', '0009_conversation_root_null'), + ] + + operations = [ + migrations.RemoveField( + model_name='conversation', + name='root', + ), + migrations.RemoveField( + model_name='statement', + name='conversation', + ), + migrations.AddField( + model_name='conversation', + name='statements', + field=models.ManyToManyField(help_text=b'The statements in this conversation.', null=True, related_name='conversation', to='django_chatterbot.Statement'), + ), + ] diff --git a/chatterbot/ext/django_chatterbot/migrations/0011_remove_statement_text_unique.py b/chatterbot/ext/django_chatterbot/migrations/0011_remove_statement_text_unique.py new file mode 100644 index 000000000..74a58165e --- /dev/null +++ b/chatterbot/ext/django_chatterbot/migrations/0011_remove_statement_text_unique.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11a1 on 2017-01-25 01:49 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_chatterbot', '0010_conversation_statements'), + ] + + operations = [ + migrations.AlterField( + model_name='statement', + name='text', + field=models.CharField(max_length=255), + ), + ] diff --git a/chatterbot/ext/django_chatterbot/models.py b/chatterbot/ext/django_chatterbot/models.py index e4a931c90..8269cfad8 100644 --- a/chatterbot/ext/django_chatterbot/models.py +++ b/chatterbot/ext/django_chatterbot/models.py @@ -1,15 +1,18 @@ from django.db import models from django.utils import timezone +from chatterbot.conversation.session import ConversationModelMixin +from chatterbot.conversation.statement import StatementModelMixin -class Statement(models.Model): +class Statement(StatementModelMixin, models.Model): """ A statement represents a single spoken entity, sentence or phrase that someone can say. """ + collection_name = 'statements' + text = models.CharField( - unique=True, blank=False, null=False, max_length=255 @@ -22,6 +25,12 @@ class Statement(models.Model): extra_data = models.CharField(max_length=500) + response = models.OneToOneField( + 'Statement', + related_name='in_response_to', + null=True + ) + # This is the confidence with which the chat bot believes # this is an accurate response. This value is set when the # statement is returned by the chat bot. @@ -40,12 +49,11 @@ def __init__(self, *args, **kwargs): # Responses to be saved if the statement is updated with the storage adapter self.response_statement_cache = [] - @property - def in_response_to(self): + def responses(self): """ - Return the response objects that are for this statement. + Return a list of statements that are known responses to this statement. """ - return Response.objects.filter(statement=self) + return Statement.objects.filter(in_response_to__text=self.text) def add_extra_data(self, key, value): """ @@ -61,45 +69,6 @@ def add_extra_data(self, key, value): self.extra_data = json.dumps(extra_data) - def add_response(self, statement): - """ - Add a response to this statement. - """ - self.response_statement_cache.append(statement) - - def remove_response(self, response_text): - """ - Removes a response from the statement's response list based - on the value of the response text. - - :param response_text: The text of the response to be removed. - :type response_text: str - """ - is_deleted = False - response = self.in_response.filter(response__text=response_text) - - if response.exists(): - is_deleted = True - - return is_deleted - - def get_response_count(self, statement): - """ - Find the number of times that the statement has been used - as a response to the current statement. - - :param statement: The statement object to get the count for. - :type statement: chatterbot.conversation.statement.Statement - - :returns: Return the number of times the statement has been used as a response. - :rtype: int - """ - try: - response = self.in_response.get(response__text=statement.text) - return response.occurrence - except Response.DoesNotExist: - return 0 - def serialize(self): """ :returns: A dictionary representation of the statement object. @@ -112,68 +81,24 @@ def serialize(self): self.extra_data = '{}' data['text'] = self.text - data['in_response_to'] = [] + data['in_response_to'] = {'text': self.in_response_to.text} data['created_at'] = self.created_at data['extra_data'] = json.loads(self.extra_data) - for response in self.in_response.all(): - data['in_response_to'].append(response.serialize()) - return data -class Response(models.Model): - """ - Connection between a response and the statement that triggered it. - - Comparble to a ManyToMany "through" table, but without the M2M indexing/relations. - The text and number of times the response has occurred are stored. - """ - - statement = models.ForeignKey( - 'Statement', - related_name='in_response' - ) - - response = models.ForeignKey( - 'Statement', - related_name='responses' - ) - - unique_together = (('statement', 'response'),) - - occurrence = models.PositiveIntegerField(default=1) - - def __str__(self): - statement = self.statement.text - response = self.response.text - return '{} => {}'.format( - statement if len(statement) <= 20 else statement[:17] + '...', - response if len(response) <= 40 else response[:37] + '...' - ) - - def serialize(self): - """ - :returns: A dictionary representation of the statement object. - :rtype: dict - """ - data = {} - - data['text'] = self.response.text - data['occurrence'] = self.occurrence - - return data - - -class Conversation(models.Model): +class Conversation(ConversationModelMixin, models.Model): """ A sequence of statements representing a conversation. """ - root = models.OneToOneField( + collection_name = 'conversations' + + statements = models.ManyToManyField( 'Statement', related_name='conversation', - help_text='The initiating statement in a conversation.' + help_text='The statements in this conversation.' ) def __str__(self): diff --git a/chatterbot/ext/django_chatterbot/views.py b/chatterbot/ext/django_chatterbot/views.py index 7f3ab81ba..e4a239124 100644 --- a/chatterbot/ext/django_chatterbot/views.py +++ b/chatterbot/ext/django_chatterbot/views.py @@ -32,8 +32,8 @@ def get_chat_session(self, request): chat_session = self.chatterbot.conversation_sessions.get(chat_session_id, None) if not chat_session: - chat_session = self.chatterbot.conversation_sessions.new() - request.session['chat_session_id'] = chat_session.id_string + chat_session = self.chatterbot.conversation_sessions.create() + request.session['chat_session_id'] = chat_session.id return chat_session @@ -43,16 +43,13 @@ class ChatterBotView(ChatterBotViewMixin, View): Provide an API endpoint to interact with ChatterBot. """ - def _serialize_conversation(self, session): - if session.conversation.empty(): - return [] + def _serialize_conversation(self, conversation): + statements = [] - conversation = [] + for statement in conversation.statements.all(): + statements.append(statement.serialize()) - for statement, response in session.conversation: - conversation.append([statement.serialize(), response.serialize()]) - - return conversation + return statements def post(self, request, *args, **kwargs): """ @@ -62,9 +59,14 @@ def post(self, request, *args, **kwargs): self.validate(input_data) + # Convert the extra_data to a string to be stored by the Django model + if 'extra_data' in input_data: + extra_data = input_data['extra_data'] + input_data['extra_data'] = json.dumps(extra_data) + chat_session = self.get_chat_session(request) - response = self.chatterbot.get_response(input_data, chat_session.id_string) + response = self.chatterbot.get_response(input_data, chat_session.id) response_data = response.serialize() try: diff --git a/chatterbot/filters.py b/chatterbot/filters.py index 1b45574d6..0c63362e9 100644 --- a/chatterbot/filters.py +++ b/chatterbot/filters.py @@ -28,13 +28,22 @@ def filter_selection(self, chatterbot, session_id): session = chatterbot.conversation_sessions.get(session_id) - if session.conversation.empty(): + # Check if a conversation of some length exists + if not session.statements.exists(): return chatterbot.storage.base_query text_of_recent_responses = [] - for statement, response in session.conversation: - text_of_recent_responses.append(response.text) + skip = True + + for statement in session.statements.all(): + + # Skip every other statement to only filter out the bot's responses + if skip: + skip = False + else: + skip = True + text_of_recent_responses.append(statement.text) query = chatterbot.storage.base_query.statement_text_not_in( text_of_recent_responses diff --git a/chatterbot/input/variable_input_type_adapter.py b/chatterbot/input/variable_input_type_adapter.py index 3c7ca09fc..e43f24c0d 100644 --- a/chatterbot/input/variable_input_type_adapter.py +++ b/chatterbot/input/variable_input_type_adapter.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals from chatterbot.input import InputAdapter -from chatterbot.conversation import Statement class VariableInputTypeAdapter(InputAdapter): @@ -45,15 +44,15 @@ def process_input(self, statement): # Convert the input string into a statement object if input_type == self.TEXT: - return Statement(statement) + return self.chatbot.storage.Statement(text=statement) # Convert input dictionary into a statement object if input_type == self.JSON: input_json = dict(statement) - text = input_json["text"] - del(input_json["text"]) + text = input_json['text'] + del input_json['text'] - return Statement(text, **input_json) + return self.chatbot.storage.Statement(text=text, **input_json) class UnrecognizedInputFormatException(Exception): """ diff --git a/chatterbot/logic/best_match.py b/chatterbot/logic/best_match.py index 079f76172..b90af6668 100644 --- a/chatterbot/logic/best_match.py +++ b/chatterbot/logic/best_match.py @@ -58,6 +58,7 @@ def process(self, input_statement): # Get all statements that are in response to the closest match response_list = self.chatbot.storage.filter( + self.chatbot.storage.Statement, in_response_to__contains=closest_match.text ) diff --git a/chatterbot/queues.py b/chatterbot/queues.py deleted file mode 100644 index 37cdf2647..000000000 --- a/chatterbot/queues.py +++ /dev/null @@ -1,80 +0,0 @@ -class FixedSizeQueue(object): - """ - This is a data structure like a queue. - Only a fixed number of items can be added. - Once the maximum is reached, when a new item is - added the oldest item in the queue will be removed. - """ - - def __init__(self, maxsize=10): - self.maxsize = maxsize - self.queue = [] - - def append(self, item): - """ - Append an element at the end of the queue. - """ - if len(self.queue) == self.maxsize: - # Remove an element from the top of the list - self.queue.pop(0) - - self.queue.append(item) - - def __len__(self): - return len(self.queue) - - def __getitem__(self, index): - return self.queue[index] - - def __contains__(self, item): - """ - Check if an element is in this queue. - """ - return item in self.queue - - def empty(self): - """ - Return True if the queue is empty, False otherwise. - """ - return len(self.queue) == 0 - - def peek(self): - """ - Return the most recent item put in the queue. - """ - if self.empty(): - return None - return self.queue[-1] - - def flush(self): - """ - Remove all elements from the queue. - """ - self.queue = [] - - -class ResponseQueue(FixedSizeQueue): - """ - An extension of the FixedSizeQueue class with - utility methods to help manage the conversation. - """ - - def get_last_response_statement(self): - """ - Return the last statement that was received. - """ - previous_interaction = self.peek() - if previous_interaction: - # Return the output statement - return previous_interaction[1] - return None - - def get_last_input_statement(self): - """ - Return the last response that was given. - """ - previous_interaction = self.peek() - if previous_interaction: - # Return the input statement - return previous_interaction[0] - return None diff --git a/chatterbot/storage/django_storage.py b/chatterbot/storage/django_storage.py index 4c1dd7005..82ea75eab 100644 --- a/chatterbot/storage/django_storage.py +++ b/chatterbot/storage/django_storage.py @@ -25,12 +25,12 @@ def find(self, statement_text): self.logger.info(str(e)) return None - def filter(self, **kwargs): + def filter(self, obj, **kwargs): """ Returns a list of statements in the database that match the parameters specified. """ - from chatterbot.ext.django_chatterbot.models import Statement + from chatterbot.ext.django_chatterbot.models import Statement, Conversation from django.db.models import Q order = kwargs.pop('order_by', None) @@ -66,49 +66,67 @@ def filter(self, **kwargs): value = kwargs['in_response__response__text'] parameters['responses__statement__text'] = value - statements = Statement.objects.filter(Q(**kwargs) | Q(**parameters)) + if obj.collection_name == 'statements': + results = Statement.objects.filter(Q(**kwargs) | Q(**parameters)) + else: + results = Conversation.objects.filter(**kwargs) if order: - statements = statements.order_by(order) + results = results.order_by(order) - return statements + return results - def update(self, statement): + def update(self, obj): """ - Update the provided statement. + Update the provided object. """ - from chatterbot.ext.django_chatterbot.models import Statement + from chatterbot.ext.django_chatterbot.models import Statement, Conversation + + if obj.collection_name == 'statements': - response_statement_cache = statement.response_statement_cache + existing_statements = Statement.objects.filter(text=obj.text) - statement, created = Statement.objects.get_or_create(text=statement.text) - statement.extra_data = getattr(statement, 'extra_data', '') - statement.save() + if existing_statements.exists(): + statement = existing_statements.first() + statement.extra_data = getattr(obj, 'extra_data', '') + else: + obj.save() + statement = obj + + for _response_statement in obj.response_statement_cache: - for _response_statement in response_statement_cache: + existing_responses = Statement.objects.filter( + text=_response_statement.text + ) + if existing_responses.exists(): + response_statement = existing_responses.first() + else: + response_statement = Statement( + text=_response_statement.text + ) - response_statement, created = Statement.objects.get_or_create( - text=_response_statement.text - ) - response_statement.extra_data = getattr(_response_statement, 'extra_data', '') - response_statement.save() + response_statement.extra_data = getattr(_response_statement, 'extra_data', '') + response_statement.save() - response, created = statement.in_response.get_or_create( - statement=statement, - response=response_statement - ) + response, created = statement.in_response.get_or_create( + statement=statement, + response=response_statement + ) - if not created: - response.occurrence += 1 - response.save() + if not created: + response.occurrence += 1 + response.save() + else: + obj.save() - return statement + return obj def get_random(self): """ Returns a random statement from the database """ from chatterbot.ext.django_chatterbot.models import Statement + return Statement.objects.order_by('?').first() def remove(self, statement_text): @@ -118,25 +136,15 @@ def remove(self, statement_text): input text. """ from chatterbot.ext.django_chatterbot.models import Statement - from chatterbot.ext.django_chatterbot.models import Response - from django.db.models import Q - statements = Statement.objects.filter(text=statement_text) - - responses = Response.objects.filter( - Q(statement__text=statement_text) | Q(response__text=statement_text) - ) - - responses.delete() - statements.delete() + Statement.objects.filter(text=statement_text).delete() def drop(self): """ Remove all data from the database. """ - from chatterbot.ext.django_chatterbot.models import Statement, Response, Conversation + from chatterbot.ext.django_chatterbot.models import Statement, Conversation Statement.objects.all().delete() - Response.objects.all().delete() Conversation.objects.all().delete() def get_response_statements(self): @@ -146,8 +154,6 @@ def get_response_statements(self): in_response_to field. Otherwise, the logic adapter may find a closest matching statement that does not have a known response. """ - from chatterbot.ext.django_chatterbot.models import Statement, Response - - responses = Response.objects.all() + from chatterbot.ext.django_chatterbot.models import Statement - return Statement.objects.filter(in_response__in=responses) + return Statement.objects.filter(in_response_to__isnull=False) diff --git a/chatterbot/storage/jsonfile.py b/chatterbot/storage/jsonfile.py index c81ef8941..0efa00311 100644 --- a/chatterbot/storage/jsonfile.py +++ b/chatterbot/storage/jsonfile.py @@ -1,6 +1,5 @@ import warnings from chatterbot.storage import StorageAdapter -from chatterbot.conversation import Response class JsonFileStorageAdapter(StorageAdapter): @@ -18,6 +17,7 @@ class JsonFileStorageAdapter(StorageAdapter): def __init__(self, **kwargs): super(JsonFileStorageAdapter, self).__init__(**kwargs) + import os from jsondb import Database if not kwargs.get('silence_performance_warning', False): @@ -26,20 +26,21 @@ def __init__(self, **kwargs): self.UnsuitableForProductionWarning ) - database_path = self.kwargs.get('database', 'database.db') - self.database = Database(database_path) + statement_database_path = self.kwargs.get('database', 'database.db') + conversation_database_path = 'conversations.db' + self.database = { + 'statements': Database(statement_database_path), + 'conversations': Database(conversation_database_path) + } - self.adapter_supports_queries = False - - def _keys(self): - # The value has to be cast as a list for Python 3 compatibility - return list(self.database[0].keys()) + # Create the statements document as an empty list + # self.database['statements'] = [] def count(self): - return len(self._keys()) + return len(self.database['statements'].keys()) def find(self, statement_text): - values = self.database.data(key=statement_text) + values = self.database['statements'].data(key=statement_text) if not values: return None @@ -48,53 +49,48 @@ def find(self, statement_text): return self.json_to_object(values) - def remove(self, statement_text): + def remove(self, obj): """ Removes the statement that matches the input text. Removes any responses from statements if the response text matches the input text. """ - for statement in self.filter(in_response_to__contains=statement_text): - statement.remove_response(statement_text) + for statement in self.filter(obj, in_response_to__contains=obj.text): + statement.remove_response(obj.text) self.update(statement) - self.database.delete(statement_text) + self.database['statements'].delete(obj.text) - def deserialize_responses(self, response_list): + def json_to_object(self, object_data): """ - Takes the list of response items and returns - the list converted to Response objects. + Converts a dictionary-like object to a Statement object. """ - proxy_statement = self.Statement('') - - for response in response_list: - data = response.copy() - text = data['text'] - del data['text'] - proxy_statement.add_response( - Response(text, **data) - ) + # Don't modify the referenced object + object_data = object_data.copy() - return proxy_statement.in_response_to + if 'text' in object_data: - def json_to_object(self, statement_data): - """ - Converts a dictionary-like object to a Statement object. - """ + # Build the objects for the response list + object_data['in_response_to'] = self.Statement( + **object_data['in_response_to'] + ) - # Don't modify the referenced object - statement_data = statement_data.copy() + # Remove the text attribute from the values + text = object_data.pop('text') - # Build the objects for the response list - statement_data['in_response_to'] = self.deserialize_responses( - statement_data['in_response_to'] - ) + return self.Statement(text, **object_data) + else: + statements = [] - # Remove the text attribute from the values - text = statement_data.pop('text') + for statement_data in object_data['statements']: + text = statement_data.pop('text') + statements.append(self.Statement(text, **statement_data)) - return self.Statement(text, **statement_data) + return self.Conversation( + id=object_data['id'], + statements=statements + ) def _all_kwargs_match_values(self, kwarguments, values): for kwarg in kwarguments: @@ -120,7 +116,7 @@ def _all_kwargs_match_values(self, kwarguments, values): return True - def filter(self, **kwargs): + def filter(self, obj, **kwargs): """ Returns a list of statements in the database that match the parameters specified. @@ -131,14 +127,10 @@ def filter(self, **kwargs): order_by = kwargs.pop('order_by', None) - for key in self._keys(): - values = self.database.data(key=key) - - # Add the text attribute to the values - values['text'] = key + for statement in self.database['statements']: - if self._all_kwargs_match_values(kwargs, values): - results.append(self.json_to_object(values)) + if self._all_kwargs_match_values(kwargs, statement): + results.append(self.json_to_object(statement)) if order_by: @@ -150,24 +142,22 @@ def filter(self, **kwargs): return results - def update(self, statement): + def update(self, obj): """ - Update a statement in the database. + Update the object in the database. """ - data = statement.serialize() + data = obj.serialize() - # Remove the text key from the data - del data['text'] - self.database.data(key=statement.text, value=data) + if 'text' in data: + statements = self.database['statements'] + statements.append(statement.serialize()) - # Make sure that an entry for each response exists - for response_statement in statement.in_response_to: - response = self.find(response_statement.text) - if not response: - response = self.Statement(response_statement.text) - self.update(response) + self.database.data(key='statements', value=statements) + else: + del data['id'] + self.database['conversations'].data(key=obj.id, value=data) - return statement + return obj def get_random(self): from random import choice @@ -175,14 +165,14 @@ def get_random(self): if self.count() < 1: raise self.EmptyDatabaseException() - statement = choice(self._keys()) - return self.find(statement) + return choice(self.database['statements']) def drop(self): """ Remove the json file database completely. """ - self.database.drop() + self.database['statements'].drop() + self.database['conversations'].drop() class UnsuitableForProductionWarning(Warning): """ diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index e025c9c3e..c96452d99 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -1,5 +1,4 @@ from chatterbot.storage import StorageAdapter -from chatterbot.conversation import Response class Query(object): @@ -43,17 +42,14 @@ def statement_response_list_contains(self, statement_text): if 'in_response_to' not in query: query['in_response_to'] = {} - if '$elemMatch' not in query['in_response_to']: - query['in_response_to']['$elemMatch'] = {} - - query['in_response_to']['$elemMatch']['text'] = statement_text + query['in_response_to']['text'] = statement_text return Query(query) - def statement_response_list_equals(self, response_list): + def statement_response_list_equals(self, response): query = self.query.copy() - query['in_response_to'] = response_list + query['in_response_to'] = response return Query(query) @@ -95,66 +91,65 @@ def __init__(self, **kwargs): # Specify the name of the database self.database = self.client[self.database_name] - # The mongo collection of statement documents - self.statements = self.database['statements'] - - # Set a requirement for the text attribute to be unique - self.statements.create_index('text', unique=True) - self.base_query = Query() def count(self): - return self.statements.count() + return self.database['statements'].count() def find(self, statement_text): query = self.base_query.statement_text_equals(statement_text) - values = self.statements.find_one(query.value()) + values = self.database['statements'].find_one(query.value()) if not values: return None del values['text'] - # Build the objects for the response list - values['in_response_to'] = self.deserialize_responses( - values.get('in_response_to', []) - ) + if 'in_response_to' in values: + values['in_response_to'] = self.deserialize_responses( + values['in_response_to'] + ) return self.Statement(statement_text, **values) - def deserialize_responses(self, response_list): + def deserialize_responses(self, statement_data): """ Takes the list of response items and returns the list converted to Response objects. """ - proxy_statement = self.Statement('') - - for response in response_list: - text = response['text'] - del response['text'] - - proxy_statement.add_response( - Response(text, **response) - ) + text = statement_data['text'] + del statement_data['text'] - return proxy_statement.in_response_to + return self.Statement(text, **statement_data) - def mongo_to_object(self, statement_data): + def mongo_to_object(self, object_data): """ Return Statement object when given data returned from Mongo DB. """ - statement_text = statement_data['text'] - del statement_data['text'] + if 'text' in object_data: + statement_text = object_data['text'] + del object_data['text'] - statement_data['in_response_to'] = self.deserialize_responses( - statement_data.get('in_response_to', []) - ) + if 'in_response_to' in object_data: + object_data['in_response_to'] = self.deserialize_responses( + object_data['in_response_to'] + ) - return self.Statement(statement_text, **statement_data) + return self.Statement(statement_text, **object_data) + else: + statements = [] + + for statement_data in object_data.get('statements'): + statements.append(self.mongo_to_object(statement_data)) + + return self.Conversation( + id=object_data['id'], + statements=statements + ) - def filter(self, **kwargs): + def filter(self, obj, **kwargs): """ Returns a list of statements in the database that match the parameters specified. @@ -165,11 +160,9 @@ def filter(self, **kwargs): order_by = kwargs.pop('order_by', None) - # Convert Response objects to data + # Convert response statement objects to data if 'in_response_to' in kwargs: - serialized_responses = [] - for response in kwargs['in_response_to']: - serialized_responses.append({'text': response}) + serialized_response = {'text': kwargs['in_response_to']} query = query.statement_response_list_equals(serialized_responses) del kwargs['in_response_to'] @@ -182,7 +175,7 @@ def filter(self, **kwargs): query = query.raw(kwargs) - matches = self.statements.find(query.value()) + matches = self.database[obj.collection_name].find(query.value()) if order_by: @@ -201,40 +194,39 @@ def filter(self, **kwargs): return results - def update(self, statement): + def update(self, obj): from pymongo import UpdateOne from pymongo.errors import BulkWriteError - data = statement.serialize() + data = obj.serialize() operations = [] update_operation = UpdateOne( - {'text': statement.text}, + {obj.pk_field: getattr(obj, obj.pk_field)}, {'$set': data}, upsert=True ) operations.append(update_operation) - # Make sure that an entry for each response is saved - for response_dict in data.get('in_response_to', []): - response_text = response_dict.get('text') + # Make sure that the response is saved + response_data = data.get('in_response_to', None) - # $setOnInsert does nothing if the document is not created + if response_data: update_operation = UpdateOne( - {'text': response_text}, - {'$set': response_dict}, + {'text': response_data['text']}, + {'$set': response_data}, upsert=True ) operations.append(update_operation) try: - self.statements.bulk_write(operations, ordered=False) + self.database[obj.collection_name].bulk_write(operations, ordered=False) except BulkWriteError as bwe: # Log the details of a bulk write error self.logger.error(str(bwe.details)) - return statement + return obj def get_random(self): """ @@ -249,21 +241,21 @@ def get_random(self): random_integer = randint(0, count - 1) - statements = self.statements.find().limit(1).skip(random_integer) + statements = self.database['statements'].find().limit(1).skip(random_integer) return self.mongo_to_object(list(statements)[0]) - def remove(self, statement_text): + def remove(self, obj): """ Removes the statement that matches the input text. Removes any responses from statements if the response text matches the input text. """ - for statement in self.filter(in_response_to__contains=statement_text): - statement.remove_response(statement_text) + for statement in self.filter(obj, in_response_to__contains=obj.text): + statement.remove_response(obj.text) self.update(statement) - self.statements.delete_one({'text': statement_text}) + self.database[obj.collection_name].delete_one({'text': obj.text}) def get_response_statements(self): """ @@ -272,7 +264,8 @@ def get_response_statements(self): in_response_to field. Otherwise, the logic adapter may find a closest matching statement that does not have a known response. """ - response_query = self.statements.distinct('in_response_to.text') + statements = self.database['statements'] + response_query = statements.distinct('in_response_to.text') _statement_query = { 'text': { @@ -282,7 +275,7 @@ def get_response_statements(self): _statement_query.update(self.base_query.value()) - statement_query = self.statements.find(_statement_query) + statement_query = statements.find(_statement_query) statement_objects = [] diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index decd49a96..ecba9ef7b 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -1,4 +1,5 @@ import logging +import os class StorageAdapter(object): @@ -16,21 +17,22 @@ def __init__(self, base_query=None, *args, **kwargs): self.adapter_supports_queries = True self.base_query = None - @property - def Statement(self): - """ - Create a storage-aware statement. - """ - import os + # Set up the class for storage-aware statements and conversations if 'DJANGO_SETTINGS_MODULE' in os.environ: - from chatterbot.ext.django_chatterbot.models import Statement - return Statement + from chatterbot.ext.django_chatterbot.models import Statement, Conversation + + self.Statement = Statement + self.Conversation = Conversation else: from chatterbot.conversation.statement import Statement - statement = Statement - statement.storage = self - return statement + from chatterbot.conversation.session import Conversation, ConversationManager + + self.Statement = Statement + self.Statement.storage = self + + self.Conversation = Conversation + self.Conversation.objects = ConversationManager(self) def generate_base_query(self, chatterbot, session_id): """ @@ -66,7 +68,7 @@ def remove(self, statement_text): 'The `remove` method is not implemented by this adapter.' ) - def filter(self, **kwargs): + def filter(self, obj, **kwargs): """ Returns a list of objects from the database. The kwargs parameter can contain any number @@ -78,7 +80,7 @@ def filter(self, **kwargs): 'The `filter` method is not implemented by this adapter.' ) - def update(self, statement): + def update(self, obj): """ Modifies an entry in the database. Creates an entry if one does not exist. @@ -113,7 +115,8 @@ def get_response_statements(self): This method may be overridden by a child class to provide more a efficient method to get these results. """ - statement_list = self.filter() + from chatterbot.conversation import Statement + statement_list = self.filter(Statement) responses = set() to_remove = list() diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 6946c437a..2cfd13bff 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -1,5 +1,4 @@ import logging -from .conversation import Statement, Response class Trainer(object): @@ -25,7 +24,7 @@ def get_or_create(self, statement_text): statement = self.storage.find(statement_text) if not statement: - statement = Statement(statement_text) + statement = self.storage.Statement(text=statement_text) return statement @@ -48,9 +47,11 @@ def __str__(self): def _generate_export_data(self): result = [] - for statement in self.storage.filter(): - for response in statement.in_response_to: - result.append([response.text, statement.text]) + for statement in self.storage.filter(self.storage.Statement): + if statement.in_response_to: + result.append([statement.text, statement.in_response_to.text]) + else: + result.append([statement.text]) return result @@ -82,9 +83,7 @@ def train(self, conversation): statement = self.get_or_create(text) if statement_history: - statement.add_response( - Response(statement_history[-1].text) - ) + statement.in_response_to = statement_history[-1] statement_history.append(statement) self.storage.update(statement) @@ -192,12 +191,13 @@ def get_statements(self): self.logger.info(u'Requesting 50 random tweets containing the word {}'.format(random_word)) tweets = self.api.GetSearch(term=random_word, count=50) for tweet in tweets: - statement = Statement(tweet.text) + statement = self.storage.Statement(text=tweet.text) if tweet.in_reply_to_status_id: try: status = self.api.GetStatus(tweet.in_reply_to_status_id) - statement.add_response(Response(status.text)) + statement.in_response_to = self.storage.Statement(text=status.text) + statements.append(statement) except TwitterError as error: self.logger.warning(str(error)) @@ -346,9 +346,7 @@ def train(self): statement.add_extra_data('addressing_speaker', row[2]) if statement_history: - statement.add_response( - Response(statement_history[-1].text) - ) + statement.in_response_to = statement_history[-1] statement_history.append(statement) self.storage.update(statement) diff --git a/docs/sessions.rst b/docs/sessions.rst index 5c3fe87e6..682d06040 100644 --- a/docs/sessions.rst +++ b/docs/sessions.rst @@ -7,18 +7,10 @@ A chat session is where the chat bot interacts with a person, and supporting multiple chat sessions means that your chat bot can have multiple different conversations with different people at the same time. -.. autoclass:: chatterbot.conversation.session.Session +.. autoclass:: chatterbot.conversation.session.Conversation :members: -.. autoclass:: chatterbot.conversation.session.ConversationSessionManager - :members: - -Each session object holds a queue of the most recent communications that have -occured durring that session. The queue holds tuples with two values each, -the first value is the input that the bot recieved and the second value is the -response that the bot returned. - -.. autoclass:: chatterbot.queues.ResponseQueue +.. autoclass:: chatterbot.conversation.session.ConversationManager :members: Session scope diff --git a/examples/django_app/tests/test_example.py b/examples/django_app/tests/test_example.py index 7f39b6a88..a2bed5a73 100644 --- a/examples/django_app/tests/test_example.py +++ b/examples/django_app/tests/test_example.py @@ -75,20 +75,19 @@ class ApiIntegrationTestCase(TestCase): def setUp(self): super(ApiIntegrationTestCase, self).setUp() + from chatterbot.ext.django_chatterbot.models import Conversation + self.api_url = reverse('chatterbot:chatterbot') - # Clear the response queue before tests - ChatterBotView.chatterbot.conversation_sessions.get( - ChatterBotView.chatterbot.default_session.id_string - ).conversation.flush() + # Clear the conversation history before tests + Conversation.objects.all().delete() def tearDown(self): super(ApiIntegrationTestCase, self).tearDown() + from chatterbot.ext.django_chatterbot.models import Conversation - # Clear the response queue after tests - ChatterBotView.chatterbot.conversation_sessions.get( - ChatterBotView.chatterbot.default_session.id_string - ).conversation.flush() + # Clear the conversation history after tests + Conversation.objects.all().delete() def _get_json(self, response): return json.loads(force_text(response.content)) @@ -113,6 +112,4 @@ def test_get_conversation(self): self.assertIn('conversation', data) self.assertEqual(len(data['conversation']), 1) - self.assertEqual(len(data['conversation'][0]), 2) - self.assertIn('text', data['conversation'][0][0]) - self.assertIn('text', data['conversation'][0][1]) + self.assertIn('text', data['conversation'][0]) diff --git a/examples/learning_feedback_example.py b/examples/learning_feedback_example.py index 4da56abc7..e97a97bbd 100644 --- a/examples/learning_feedback_example.py +++ b/examples/learning_feedback_example.py @@ -57,7 +57,7 @@ def get_feedback(): # Update the conversation history for the bot # It is important that this happens last, after the learning step bot.conversation_sessions.update( - bot.default_session.id_string, + bot.default_session.id, (statement, response, ) ) diff --git a/tests/base_case.py b/tests/base_case.py index d2f0eed2a..11783172b 100644 --- a/tests/base_case.py +++ b/tests/base_case.py @@ -63,3 +63,5 @@ def get_kwargs(self): kwargs['database'] = self.random_string() kwargs['storage_adapter'] = 'chatterbot.storage.MongoDatabaseAdapter' return kwargs + +ChatBotTestCase = ChatBotMongoTestCase diff --git a/tests/conversation_tests/test_responses.py b/tests/conversation_tests/test_responses.py deleted file mode 100644 index f1a71cf81..000000000 --- a/tests/conversation_tests/test_responses.py +++ /dev/null @@ -1,9 +0,0 @@ -from unittest import TestCase -from chatterbot.conversation import Response - - -class ResponseTests(TestCase): - - def setUp(self): - self.response = Response("A test response.") - diff --git a/tests/conversation_tests/test_statements.py b/tests/conversation_tests/test_statements.py index 8192c45fd..986ce7754 100644 --- a/tests/conversation_tests/test_statements.py +++ b/tests/conversation_tests/test_statements.py @@ -1,22 +1,18 @@ # -*- coding: utf-8 -*- from unittest import TestCase -from chatterbot.conversation import Statement, Response +from chatterbot.conversation import Statement class StatementTests(TestCase): - def setUp(self): - self.statement = Statement("A test statement.") - def test_list_equality(self): """ It should be possible to check if a statement exists in the list of statements that another statement has been issued in response to. """ - self.statement.add_response(Response("Yo")) - self.assertEqual(len(self.statement.in_response_to), 1) - self.assertIn(Response("Yo"), self.statement.in_response_to) + statements = [Statement('Hi'), Statement('Hello')] + self.assertEqual(Statement('Hi'), statements) def test_list_equality_unicode(self): """ @@ -24,62 +20,5 @@ def test_list_equality_unicode(self): is in a list of other statements when the statements text is unicode. """ - statements = [Statement("Hello"), Statement("我很好太感谢")] - statement = Statement("我很好太感谢") - self.assertIn(statement, statements) - - def test_update_response_list_new(self): - self.statement.add_response(Response("Hello")) - self.assertTrue(len(self.statement.in_response_to), 1) - - def test_update_response_list_existing(self): - response = Response("Hello") - self.statement.add_response(response) - self.statement.add_response(response) - self.assertTrue(len(self.statement.in_response_to), 1) - - def test_remove_response_exists(self): - self.statement.add_response(Response("Testing")) - removed = self.statement.remove_response("Testing") - self.assertTrue(removed) - - def test_remove_response_does_not_exist(self): - self.statement.add_response(Response("Testing")) - removed = self.statement.remove_response("Test") - self.assertFalse(removed) - - def test_serializer(self): - data = self.statement.serialize() - self.assertEqual(self.statement.text, data["text"]) - - def test_occurrence_count_for_new_statement(self): - """ - When the occurrence is updated for a statement that - previously did not exist as a statement that the current - statement was issued in response to, then the new statement - should be added to the response list and the occurence count - for that response should be set to 1. - """ - response = Response("This is a test.") - - self.statement.add_response(response) - self.assertTrue(self.statement.get_response_count(response), 1) - - def test_occurrence_count_for_existing_statement(self): - self.statement.add_response(Response("ABC")) - self.statement.add_response(Response("ABC")) - self.assertTrue( - self.statement.get_response_count(Response("ABC")), - 2 - ) - - def test_occurrence_count_incremented(self): - self.statement.add_response(Response("ABC")) - self.statement.add_response(Response("ABC")) - - self.assertEqual(len(self.statement.in_response_to), 1) - self.assertEqual(self.statement.in_response_to[0].occurrence, 2) - - def test_add_non_response(self): - with self.assertRaises(Statement.InvalidTypeException): - self.statement.add_response(Statement("Blah")) + statements = [Statement('Hello'), Statement('我很好太感谢')] + self.assertIn(Statement('我很好太感谢'), statements) diff --git a/tests/input_adapter_tests/test_variable_input_type_adapter.py b/tests/input_adapter_tests/test_variable_input_type_adapter.py index eae3ceb23..e826d5ee2 100644 --- a/tests/input_adapter_tests/test_variable_input_type_adapter.py +++ b/tests/input_adapter_tests/test_variable_input_type_adapter.py @@ -1,12 +1,14 @@ -from unittest import TestCase +from tests.base_case import ChatBotTestCase from chatterbot.conversation import Statement from chatterbot.input import VariableInputTypeAdapter -class VariableInputTypeAdapterTests(TestCase): +class VariableInputTypeAdapterTests(ChatBotTestCase): def setUp(self): + super(VariableInputTypeAdapterTests, self).setUp() self.adapter = VariableInputTypeAdapter() + self.adapter.set_chatbot(self.chatbot) def test_statement_returned_dict(self): data = { diff --git a/tests/logic_adapter_tests/best_match_integration_tests/test_levenshtein_distance.py b/tests/logic_adapter_tests/best_match_integration_tests/test_levenshtein_distance.py index 9077b26e3..cf4b8d9f7 100644 --- a/tests/logic_adapter_tests/best_match_integration_tests/test_levenshtein_distance.py +++ b/tests/logic_adapter_tests/best_match_integration_tests/test_levenshtein_distance.py @@ -34,7 +34,8 @@ def test_get_closest_statement(self): Statement('Yuck, black licorice jelly beans.', in_response_to=[Response('What is the meaning of life?')]), Statement('I hear you are going on a quest?', in_response_to=[Response('Who do you love?')]), ] - self.adapter.chatbot.storage.filter = MagicMock(return_value=possible_choices) + for choice in possible_choices: + self.adapter.chatbot.storage.update(choice) statement = Statement('What is your quest?') @@ -46,7 +47,8 @@ def test_confidence_exact_match(self): possible_choices = [ Statement('What is your quest?', in_response_to=[Response('What is your quest?')]) ] - self.adapter.chatbot.storage.filter = MagicMock(return_value=possible_choices) + for choice in possible_choices: + self.adapter.chatbot.storage.update(choice) statement = Statement('What is your quest?') match = self.adapter.get(statement) @@ -57,7 +59,8 @@ def test_confidence_half_match(self): possible_choices = [ Statement('xxyy', in_response_to=[Response('xxyy')]) ] - self.adapter.chatbot.storage.filter = MagicMock(return_value=possible_choices) + for choice in possible_choices: + self.adapter.chatbot.storage.update(choice) statement = Statement('wwxx') match = self.adapter.get(statement) @@ -68,7 +71,8 @@ def test_confidence_no_match(self): possible_choices = [ Statement('xxx', in_response_to=[Response('xxx')]) ] - self.adapter.chatbot.storage.filter = MagicMock(return_value=possible_choices) + for choice in possible_choices: + self.adapter.chatbot.storage.update(choice) statement = Statement('yyy') match = self.adapter.get(statement) diff --git a/tests/logic_adapter_tests/best_match_integration_tests/test_synset_distance.py b/tests/logic_adapter_tests/best_match_integration_tests/test_synset_distance.py index 005b1ef29..67facea17 100644 --- a/tests/logic_adapter_tests/best_match_integration_tests/test_synset_distance.py +++ b/tests/logic_adapter_tests/best_match_integration_tests/test_synset_distance.py @@ -37,9 +37,8 @@ def test_get_closest_statement(self): Statement('This is a beautiful swamp.', in_response_to=[Response('This is a beautiful swamp.')]), Statement('It smells like a swamp.', in_response_to=[Response('It smells like a swamp.')]) ] - self.adapter.chatbot.storage.filter = MagicMock( - return_value=possible_choices - ) + for choice in possible_choices: + self.adapter.chatbot.storage.update(choice) statement = Statement('This is a lovely swamp.') match = self.adapter.get(statement) diff --git a/tests/logic_adapter_tests/test_low_confidence_adapter.py b/tests/logic_adapter_tests/test_low_confidence_adapter.py index cda7526f9..23ac001d4 100644 --- a/tests/logic_adapter_tests/test_low_confidence_adapter.py +++ b/tests/logic_adapter_tests/test_low_confidence_adapter.py @@ -1,5 +1,4 @@ from unittest import TestCase -from mock import MagicMock from chatterbot.logic import LowConfidenceAdapter from chatterbot.conversation import Statement, Response from tests.base_case import ChatBotTestCase @@ -37,7 +36,8 @@ def setUp(self): Response('Who do you love?') ]), ] - self.adapter.chatbot.storage.filter = MagicMock(return_value=possible_choices) + for choice in possible_choices: + self.adapter.chatbot.storage.update(choice) def test_high_confidence(self): """ diff --git a/tests/storage_adapter_tests/test_json_file_storage_adapter.py b/tests/storage_adapter_tests/test_json_file_storage_adapter.py index 26b0ffce1..4f848fe88 100644 --- a/tests/storage_adapter_tests/test_json_file_storage_adapter.py +++ b/tests/storage_adapter_tests/test_json_file_storage_adapter.py @@ -95,16 +95,12 @@ def test_update_modifies_existing_statement(self): ) # Update the statement value - statement.add_response( - Response("New response") - ) + statement.in_response_to = Statement('New response') self.adapter.update(statement) # Check that the values have changed found_statement = self.adapter.find(statement.text) - self.assertEqual( - len(found_statement.in_response_to), 1 - ) + self.assertEqual(found_statement.in_response_to, statement.in_response_to) def test_get_random_returns_statement(self): statement = Statement("New statement") @@ -158,32 +154,31 @@ def test_update_saves_statement_with_multiple_responses(self): self.assertEqual(len(response.in_response_to), 2) def test_getting_and_updating_statement(self): - statement = Statement("Hi") + statement1 = Statement('Hi', in_response_to=Statement('Hello')) + statement2 = Statement('Hi', in_response_to=Statement('Hello')) self.adapter.update(statement) - statement.add_response(Response("Hello")) - statement.add_response(Response("Hello")) - self.adapter.update(statement) + self.adapter.update(statement1) + self.adapter.update(statement2) - response = self.adapter.find(statement.text) + results = self.adapter.filter(statement.text) - self.assertEqual(len(response.in_response_to), 1) - self.assertEqual(response.in_response_to[0].occurrence, 2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0].text, statement1.text) + self.assertEqual(results[1].text, statement2.text) + self.assertEqual(results[0].in_response_to, statement1.in_response_to) + self.assertEqual(results[1].in_response_to, statement2.in_response_to) - def test_deserialize_responses(self): - response_list = [ - {"text": "Test", "occurrence": 3}, - {"text": "Testing", "occurrence": 1}, - ] - results = self.adapter.deserialize_responses(response_list) + def test_deserialize_response(self): + result = self.adapter.deserialize_responses({'text': 'Test'}) - self.assertEqual(len(results), 2) + self.assertEqual(result.text, 'Test') def test_remove(self): text = "Sometimes you have to run before you can walk." statement = Statement(text) self.adapter.update(statement) - self.adapter.remove(statement.text) + self.adapter.remove(statement) result = self.adapter.find(text) self.assertIsNone(result) @@ -195,8 +190,8 @@ def test_remove_response(self): in_response_to=[Response(text)] ) self.adapter.update(statement) - self.adapter.remove(statement.text) - results = self.adapter.filter(in_response_to__contains=text) + self.adapter.remove(statement) + results = self.adapter.filter(Statement, in_response_to__contains=text) self.assertEqual(results, []) @@ -242,7 +237,7 @@ def setUp(self): def test_filter_text_no_matches(self): self.adapter.update(self.statement1) - results = self.adapter.filter(text="Howdy") + results = self.adapter.filter(Statement, text="Howdy") self.assertEqual(len(results), 0) @@ -250,6 +245,7 @@ def test_filter_in_response_to_no_matches(self): self.adapter.update(self.statement1) results = self.adapter.filter( + Statement, in_response_to="Maybe" ) self.assertEqual(len(results), 0) @@ -266,7 +262,7 @@ def test_filter_equal_results(self): self.adapter.update(statement1) self.adapter.update(statement2) - results = self.adapter.filter(in_response_to=[]) + results = self.adapter.filter(Statement, in_response_to=[]) self.assertEqual(len(results), 2) self.assertIn(statement1, results) self.assertIn(statement2, results) @@ -276,6 +272,7 @@ def test_filter_contains_result(self): self.adapter.update(self.statement2) results = self.adapter.filter( + Statement, in_response_to__contains="Why are you counting?" ) self.assertEqual(len(results), 1) @@ -285,6 +282,7 @@ def test_filter_contains_no_result(self): self.adapter.update(self.statement1) results = self.adapter.filter( + Statement, in_response_to__contains="How do you do?" ) self.assertEqual(results, []) @@ -294,6 +292,7 @@ def test_filter_multiple_parameters(self): self.adapter.update(self.statement2) results = self.adapter.filter( + Statement, text="Testing...", in_response_to__contains="Why are you counting?" ) @@ -306,6 +305,7 @@ def test_filter_multiple_parameters_no_results(self): self.adapter.update(self.statement2) results = self.adapter.filter( + Statement, text="Test", in_response_to__contains="Not an existing response." ) @@ -322,7 +322,7 @@ def test_filter_no_parameters(self): self.adapter.update(statement1) self.adapter.update(statement2) - results = self.adapter.filter() + results = self.adapter.filter(Statement) self.assertEqual(len(results), 2) @@ -336,6 +336,7 @@ def test_filter_returns_statement_with_multiple_responses(self): ) self.adapter.update(statement) response = self.adapter.filter( + Statement, in_response_to__contains="Thanks." ) @@ -357,7 +358,7 @@ def test_response_list_in_results(self): ] ) self.adapter.update(statement) - found = self.adapter.filter(text=statement.text) + found = self.adapter.filter(Statement, text=statement.text) self.assertEqual(len(found[0].in_response_to), 1) self.assertEqual(type(found[0].in_response_to[0]), Response) @@ -375,7 +376,7 @@ def test_order_by_text(self): self.adapter.update(statement_a) self.adapter.update(statement_b) - results = self.adapter.filter(order_by='text') + results = self.adapter.filter(Statement, order_by='text') self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) @@ -399,7 +400,7 @@ def test_order_by_created_at(self): self.adapter.update(statement_a) self.adapter.update(statement_b) - results = self.adapter.filter(order_by='created_at') + results = self.adapter.filter(Statement, order_by='created_at') self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) diff --git a/tests/storage_adapter_tests/test_mongo_adapter.py b/tests/storage_adapter_tests/test_mongo_adapter.py index c95f2b255..5f9e10d1f 100644 --- a/tests/storage_adapter_tests/test_mongo_adapter.py +++ b/tests/storage_adapter_tests/test_mongo_adapter.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest import SkipTest from chatterbot.storage import MongoDatabaseAdapter -from chatterbot.conversation import Statement, Response +from chatterbot.conversation import Statement class MongoAdapterTestCase(TestCase): @@ -90,9 +90,7 @@ def test_update_modifies_existing_statement(self): ) # Update the statement value - statement.add_response( - Response("New response") - ) + statement.in_response_to = Statement("New response") self.adapter.update(statement) # Check that the values have changed @@ -165,14 +163,10 @@ def test_getting_and_updating_statement(self): self.assertEqual(len(response.in_response_to), 1) self.assertEqual(response.in_response_to[0].occurrence, 2) - def test_deserialize_responses(self): - response_list = [ - {"text": "Test", "occurrence": 3}, - {"text": "Testing", "occurrence": 1}, - ] - results = self.adapter.deserialize_responses(response_list) + def test_deserialize_response(self): + result = self.adapter.deserialize_responses({'text': 'Test'}) - self.assertEqual(len(results), 2) + self.assertEqual(result.text, 'Test') def test_mongo_to_object(self): self.adapter.update( @@ -183,7 +177,9 @@ def test_mongo_to_object(self): ] ) ) - statement_data = self.adapter.statements.find_one({'text': 'Hello'}) + statement_data = self.adapter.database['statements'].find_one( + {'text': 'Hello'} + ) obj = self.adapter.mongo_to_object(statement_data) @@ -211,7 +207,7 @@ def test_remove(self): text = "Sometimes you have to run before you can walk." statement = Statement(text) self.adapter.update(statement) - self.adapter.remove(statement.text) + self.adapter.remove(statement) result = self.adapter.find(text) self.assertIsNone(result) @@ -223,8 +219,8 @@ def test_remove_response(self): in_response_to=[Response(text)] ) self.adapter.update(statement) - self.adapter.remove(statement.text) - results = self.adapter.filter(in_response_to__contains=text) + self.adapter.remove(statement) + results = self.adapter.filter(Statement, in_response_to__contains=text) self.assertEqual(results, []) @@ -270,14 +266,14 @@ def setUp(self): def test_filter_text_no_matches(self): self.adapter.update(self.statement1) - results = self.adapter.filter(text="Howdy") + results = self.adapter.filter(Statement, text="Howdy") self.assertEqual(len(results), 0) def test_filter_in_response_to_no_matches(self): self.adapter.update(self.statement1) - results = self.adapter.filter(in_response_to="Maybe") + results = self.adapter.filter(Statement, in_response_to="Maybe") self.assertEqual(len(results), 0) def test_filter_equal_results(self): @@ -292,7 +288,7 @@ def test_filter_equal_results(self): self.adapter.update(statement1) self.adapter.update(statement2) - results = self.adapter.filter(in_response_to=[]) + results = self.adapter.filter(Statement, in_response_to=[]) self.assertEqual(len(results), 2) self.assertIn(statement1, results) self.assertIn(statement2, results) @@ -302,6 +298,7 @@ def test_filter_contains_result(self): self.adapter.update(self.statement2) results = self.adapter.filter( + Statement, in_response_to__contains="Why are you counting?" ) self.assertEqual(len(results), 1) @@ -311,6 +308,7 @@ def test_filter_contains_no_result(self): self.adapter.update(self.statement1) results = self.adapter.filter( + Statement, in_response_to__contains="How do you do?" ) self.assertEqual(results, []) @@ -320,6 +318,7 @@ def test_filter_multiple_parameters(self): self.adapter.update(self.statement2) results = self.adapter.filter( + Statement, text="Testing...", in_response_to__contains="Why are you counting?" ) @@ -332,6 +331,7 @@ def test_filter_multiple_parameters_no_results(self): self.adapter.update(self.statement2) results = self.adapter.filter( + Statement, text="Test", in_response_to__contains="Not an existing response." ) @@ -348,7 +348,7 @@ def test_filter_no_parameters(self): self.adapter.update(statement1) self.adapter.update(statement2) - results = self.adapter.filter() + results = self.adapter.filter(Statement) self.assertEqual(len(results), 2) @@ -362,6 +362,7 @@ def test_filter_returns_statement_with_multiple_responses(self): ) self.adapter.update(statement) response = self.adapter.filter( + Statement, in_response_to__contains="Thanks." ) @@ -378,12 +379,10 @@ def test_response_list_in_results(self): """ statement = Statement( "The first is to help yourself, the second is to help others.", - in_response_to=[ - Response("Why do people have two hands?") - ] + in_response_to=Statement("Why do people have two hands?") ) self.adapter.update(statement) - found = self.adapter.filter(text=statement.text) + found = self.adapter.filter(Statement, text=statement.text) self.assertEqual(len(found[0].in_response_to), 1) self.assertEqual(type(found[0].in_response_to[0]), Response) @@ -401,7 +400,7 @@ def test_order_by_text(self): self.adapter.update(statement_a) self.adapter.update(statement_b) - results = self.adapter.filter(order_by='text') + results = self.adapter.filter(Statement, order_by='text') self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) @@ -425,7 +424,7 @@ def test_order_by_created_at(self): self.adapter.update(statement_a) self.adapter.update(statement_b) - results = self.adapter.filter(order_by='created_at') + results = self.adapter.filter(Statement, order_by='created_at') self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) diff --git a/tests/storage_adapter_tests/test_storage_adapter.py b/tests/storage_adapter_tests/test_storage_adapter.py index 35c077d05..7db2ba021 100644 --- a/tests/storage_adapter_tests/test_storage_adapter.py +++ b/tests/storage_adapter_tests/test_storage_adapter.py @@ -1,6 +1,7 @@ from unittest import TestCase from chatterbot.storage import StorageAdapter -from chatterbot.conversation import Statement, Response +from chatterbot.conversation import Statement + class StorageAdapterTestCase(TestCase): """ @@ -24,15 +25,15 @@ def test_find(self): def test_filter(self): with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): - self.adapter.filter() + self.adapter.filter(None) def test_remove(self): with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): - self.adapter.remove('') + self.adapter.remove(None) def test_update(self): with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): - self.adapter.update('') + self.adapter.update(None) def test_get_random(self): with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 88bb886ff..b96fcedb1 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from .base_case import ChatBotTestCase -from chatterbot.conversation import Statement, Response +from chatterbot.conversation import Statement class ChatterBotResponseTestCase(ChatBotTestCase): @@ -8,11 +8,7 @@ class ChatterBotResponseTestCase(ChatBotTestCase): def setUp(self): super(ChatterBotResponseTestCase, self).setUp() - response_list = [ - Response('Hi') - ] - - self.test_statement = Statement('Hello', in_response_to=response_list) + self.test_statement = Statement('Hello', in_response_to=Statement('Hi')) def test_empty_database(self): """ @@ -42,11 +38,9 @@ def test_statement_added_to_recent_response_list(self): """ statement_text = 'Wow!' response = self.chatbot.get_response(statement_text) - session = self.chatbot.conversation_sessions.get( - self.chatbot.default_session.id_string - ) + conversation = self.chatbot.storage.filter(self.chatbot.storage.Conversation)[0] - self.assertIn(statement_text, session.conversation[0]) + self.assertIn(statement_text, conversation.statements.all()) self.assertEqual(response, statement_text) def test_response_known(self): @@ -63,8 +57,7 @@ def test_response_format(self): statement_object = self.chatbot.storage.find(response.text) self.assertEqual(response, self.test_statement.text) - self.assertIsLength(statement_object.in_response_to, 1) - self.assertIn('Hi', statement_object.in_response_to) + self.assertEqual('Hi', statement_object.in_response_to) def test_second_response_format(self): self.chatbot.storage.update(self.test_statement) @@ -78,8 +71,7 @@ def test_second_response_format(self): self.assertIsNotNone(self.chatbot.storage.find('How are you?')) self.assertEqual(second_response, self.test_statement.text) - self.assertIsLength(statement.in_response_to, 1) - self.assertIn('Hi', statement.in_response_to) + self.assertEqual('Hi', statement.in_response_to) def test_get_response_unicode(self): """ @@ -108,9 +100,10 @@ def test_response_extra_data(self): def test_generate_response(self): statement = Statement('Many insects adopt a tripedal gait for rapid yet stable walking.') + conversation = self.chatbot.get_or_create_default_conversation() input_statement, response = self.chatbot.generate_response( statement, - self.chatbot.default_session.id + conversation.id ) self.assertEqual(input_statement, statement) diff --git a/tests/test_context.py b/tests/test_context.py deleted file mode 100644 index 7fa6f411e..000000000 --- a/tests/test_context.py +++ /dev/null @@ -1,21 +0,0 @@ -from .base_case import ChatBotTestCase - - -class AdapterTests(ChatBotTestCase): - - def test_modify_chatbot(self): - """ - When one adapter modifies its chatbot instance, - the change should be the same in all other adapters. - """ - session = self.chatbot.input.chatbot.conversation_sessions.new() - self.chatbot.input.chatbot.conversation_sessions.update( - session.id_string, - ('A', 'B', ) - ) - - session = self.chatbot.output.chatbot.conversation_sessions.get( - session.id_string - ) - - self.assertIn(('A', 'B', ), session.conversation) diff --git a/tests/test_conversations.py b/tests/test_conversations.py new file mode 100644 index 000000000..8b935f44e --- /dev/null +++ b/tests/test_conversations.py @@ -0,0 +1,76 @@ +from unittest import TestCase +from chatterbot.conversation import Statement +from chatterbot.conversation.session import Conversation, ConversationManager +from .base_case import ChatBotTestCase + + +class ConversationTestCase(TestCase): + + def setUp(self): + super(ConversationTestCase, self).setUp() + self.conversation = Conversation() + + def test_id(self): + self.assertEqual(str(self.conversation.uuid), self.conversation.id) + + def test_no_last_response_statement(self): + self.assertIsNone(self.conversation.get_last_response_statement()) + + def test_get_last_response_statement(self): + """ + Make sure that the get last statement method + returns the last statement that was issued. + """ + self.conversation.statements.add(Statement('Test statement 1')) + self.conversation.statements.add(Statement('Test response 1')) + self.conversation.statements.add(Statement('Test statement 2')) + self.conversation.statements.add(Statement('Test response 2')) + + last_statement = self.conversation.get_last_response_statement() + self.assertEqual(last_statement, 'Test response 2') + + def test_no_last_input_statement(self): + self.assertIsNone(self.conversation.get_last_input_statement()) + + def test_get_last_input_statement(self): + """ + Make sure that the get last statement method + returns the last statement that was issued. + """ + self.conversation.statements.add(Statement('Test statement 1')) + self.conversation.statements.add(Statement('Test response 1')) + self.conversation.statements.add(Statement('Test statement 2')) + self.conversation.statements.add(Statement('Test response 2')) + + last_statement = self.conversation.get_last_input_statement() + self.assertEqual(last_statement, 'Test statement 2') + + +class ConversationManagerTestCase(ChatBotTestCase): + + def setUp(self): + super(ConversationManagerTestCase, self).setUp() + self.manager = ConversationManager(self.chatbot.storage) + + def test_new(self): + conversation = self.manager.create() + + self.assertTrue(isinstance(conversation, Conversation)) + self.assertEqual(conversation.id, self.manager.get(conversation.id).id) + + def test_get(self): + conversation = self.manager.create() + returned_conversation = self.manager.get(conversation.id) + + self.assertEqual(conversation.id, returned_conversation.id) + + def test_get_invalid_id(self): + returned_conversation = self.manager.get('--invalid--') + + self.assertIsNone(returned_conversation) + + def test_get_invalid_id_with_deafult(self): + returned_conversation = self.manager.get('--invalid--', 'default_value') + + self.assertEqual(returned_conversation, 'default_value') + diff --git a/tests/test_queries.py b/tests/test_queries.py index ff1cd9426..f3cd40524 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -25,9 +25,8 @@ def test_statement_response_list_contains(self): query = self.query.statement_response_list_contains('Hey') self.assertIn('in_response_to', query.value()) - self.assertIn('$elemMatch', query.value()['in_response_to']) - self.assertIn('text', query.value()['in_response_to']['$elemMatch']) - self.assertEqual('Hey', query.value()['in_response_to']['$elemMatch']['text']) + self.assertIn('text', query.value()['in_response_to']) + self.assertEqual('Hey', query.value()['in_response_to']['text']) def test_statement_response_list_equals(self): query = self.query.statement_response_list_equals([]) diff --git a/tests/test_queues.py b/tests/test_queues.py deleted file mode 100644 index 7dd30caca..000000000 --- a/tests/test_queues.py +++ /dev/null @@ -1,80 +0,0 @@ -from unittest import TestCase -from chatterbot import queues - - -class FixedSizeQueueTests(TestCase): - - def setUp(self): - self.queue = queues.FixedSizeQueue(maxsize=2) - - def test_append(self): - self.queue.append(0) - self.assertIn(0, self.queue) - - def test_contains(self): - self.queue.queue.append(0) - self.assertIn(0, self.queue) - - def test_empty(self): - self.assertTrue(self.queue.empty()) - - def test_not_empty(self): - self.queue.append(0) - self.assertFalse(self.queue.empty()) - - def test_maxsize(self): - self.queue.append(0) - self.queue.append(1) - self.queue.append(2) - - self.assertNotIn(0, self.queue) - self.assertIn(1, self.queue) - self.assertIn(2, self.queue) - - def test_peek_empty_queue(self): - self.assertIsNone(self.queue.peek()) - - def test_peek(self): - self.queue.append(4) - self.queue.append(5) - self.queue.append(6) - - self.assertEqual(self.queue.peek(), 6) - - -class ResponseQueueTests(TestCase): - """ - The response view is a version of the FixedSizeQueue with - additional utility methods to help manage the conversation. - """ - - def setUp(self): - self.queue = queues.ResponseQueue(maxsize=2) - - def test_no_last_response_statement(self): - self.assertIsNone(self.queue.get_last_response_statement()) - - def test_get_last_response_statement(self): - """ - Make sure that the get last statement method - returns the last statement that was issued. - """ - self.queue.append(('Test statement 1', 'Test response 1', )) - self.queue.append(('Test statement 2', 'Test response 2', )) - - last_statement = self.queue.get_last_response_statement() - self.assertEqual(last_statement, 'Test response 2') - - def test_no_last_input_statement(self): - self.assertIsNone(self.queue.get_last_input_statement()) - - def test_get_last_input_statement(self): - """ - Make sure that the get last statement method - returns the last statement that was issued. - """ - self.queue.append(('Test statement 1', 'Test response 1', )) - self.queue.append(('Test statement 2', 'Test response 2', )) - - last_statement = self.queue.get_last_input_statement() - self.assertEqual(last_statement, 'Test statement 2') diff --git a/tests/test_response_selection.py b/tests/test_response_selection.py index 83a98ba88..41417d133 100644 --- a/tests/test_response_selection.py +++ b/tests/test_response_selection.py @@ -1,16 +1,16 @@ from unittest import TestCase from chatterbot import response_selection -from chatterbot.conversation import Statement, Response +from chatterbot.conversation import Statement class ResponseSelectionTests(TestCase): def test_get_most_frequent_response(self): statement_list = [ - Statement('What... is your quest?', in_response_to=[Response('Hello', occurrence=2)]), - Statement('This is a phone.', in_response_to=[Response('Hello', occurrence=4)]), - Statement('A what?', in_response_to=[Response('Hello', occurrence=2)]), - Statement('A phone.', in_response_to=[Response('Hello', occurrence=1)]) + Statement('What... is your quest?', in_response_to=Statement('Hello', occurrence=2)), + Statement('This is a phone.', in_response_to=Statement('Hello', occurrence=4)), + Statement('A what?', in_response_to=Statement('Hello', occurrence=2)), + Statement('A phone.', in_response_to=Statement('Hello', occurrence=1)) ] output = response_selection.get_most_frequent_response( diff --git a/tests/test_sessions.py b/tests/test_sessions.py deleted file mode 100644 index 741319851..000000000 --- a/tests/test_sessions.py +++ /dev/null @@ -1,50 +0,0 @@ -from unittest import TestCase -from chatterbot.conversation.session import Session, ConversationSessionManager - - -class SessionTestCase(TestCase): - - def test_id_string(self): - session = Session() - self.assertEqual(str(session.uuid), session.id_string) - - -class ConversationSessionManagerTestCase(TestCase): - - def setUp(self): - super(ConversationSessionManagerTestCase, self).setUp() - self.manager = ConversationSessionManager() - - def test_new(self): - session = self.manager.new() - - self.assertTrue(isinstance(session, Session)) - self.assertIn(session.id_string, self.manager.sessions) - self.assertEqual(session, self.manager.sessions[session.id_string]) - - def test_get(self): - session = self.manager.new() - returned_session = self.manager.get(session.id_string) - - self.assertEqual(session.id_string, returned_session.id_string) - - def test_get_invalid_id(self): - returned_session = self.manager.get('--invalid--') - - self.assertIsNone(returned_session) - - def test_get_invalid_id_with_deafult(self): - returned_session = self.manager.get('--invalid--', 'default_value') - - self.assertEqual(returned_session, 'default_value') - - def test_update(self): - session = self.manager.new() - self.manager.update(session.id_string, ('A', 'B', )) - - session_ids = list(self.manager.sessions.keys()) - session_id = session_ids[0] - - self.assertEqual(len(session_ids), 1) - self.assertEqual(len(self.manager.get(session_id).conversation), 1) - self.assertEqual(('A', 'B', ), self.manager.get(session_id).conversation[0]) diff --git a/tests/training_tests/test_list_training.py b/tests/training_tests/test_list_training.py index 8e1479470..66995cf50 100644 --- a/tests/training_tests/test_list_training.py +++ b/tests/training_tests/test_list_training.py @@ -44,9 +44,10 @@ def test_training_increments_occurrence_count(self): self.chatbot.train(conversation) statements = self.chatbot.storage.filter( + self.chatbot.storage.Statement, in_response_to__contains="Do you like my hat?" ) - response = statements[0].in_response_to[0] + response = statements[0].in_response_to self.assertEqual(response.occurrence, 2) @@ -76,22 +77,13 @@ def test_database_has_correct_format(self): # There should be a total of 9 statements in the database after training self.assertEqual(self.chatbot.storage.count(), 9) - # The first statement should be in response to another statement - self.assertEqual( - len(self.chatbot.storage.find(conversation[0]).in_response_to), - 0 - ) - - # The second statement should have one response - self.assertEqual( - len(self.chatbot.storage.find(conversation[1]).in_response_to), - 1 - ) + # The first statement should not be in response to another statement + self.assertIsNone(self.chatbot.storage.find(conversation[0]).in_response_to) # The second statement should be in response to the first statement - self.assertIn( + self.assertEqual( conversation[0], - self.chatbot.storage.find(conversation[1]).in_response_to, + self.chatbot.storage.find(conversation[1]).in_response_to ) def test_training_with_unicode_characters(self): diff --git a/tests/training_tests/test_twitter_trainer.py b/tests/training_tests/test_twitter_trainer.py index 9f265da39..36049c1f5 100644 --- a/tests/training_tests/test_twitter_trainer.py +++ b/tests/training_tests/test_twitter_trainer.py @@ -79,5 +79,5 @@ def test_get_statements(self): def test_train(self): self.trainer.train() - statement_created = self.trainer.storage.filter() + statement_created = self.trainer.storage.filter(self.trainer.storage.Statement) self.assertTrue(len(statement_created)) diff --git a/tests_django/integration_tests/test_statement_integration.py b/tests_django/integration_tests/test_statement_integration.py index 8b29d5a4c..693291e64 100644 --- a/tests_django/integration_tests/test_statement_integration.py +++ b/tests_django/integration_tests/test_statement_integration.py @@ -1,9 +1,7 @@ from django.test import TestCase from django.utils import timezone from chatterbot.conversation import Statement as StatementObject -from chatterbot.conversation import Response as ResponseObject from chatterbot.ext.django_chatterbot.models import Statement as StatementModel -from chatterbot.ext.django_chatterbot.models import Response as ResponseModel class StatementIntegrationTestCase(TestCase): @@ -40,36 +38,6 @@ def test_add_extra_data(self): self.object.add_extra_data('key', 'value') self.model.add_extra_data('key', 'value') - def test_add_response(self): - self.assertTrue(hasattr(self.object, 'add_response')) - self.assertTrue(hasattr(self.model, 'add_response')) - - def test_remove_response(self): - self.object.add_response(ResponseObject('Hello')) - model_response_statement = StatementModel.objects.create(text='Hello') - self.model.save() - self.model.in_response.create(statement=self.model, response=model_response_statement) - - object_removed = self.object.remove_response('Hello') - model_removed = self.model.remove_response('Hello') - - self.assertTrue(object_removed) - self.assertTrue(model_removed) - - def test_get_response_count(self): - self.object.add_response(ResponseObject('Hello', occurrence=2)) - model_response_statement = StatementModel.objects.create(text='Hello') - self.model.save() - self.model.in_response.create( - statement=self.model, response=model_response_statement, occurrence=2 - ) - - object_count = self.object.get_response_count(StatementObject(text='Hello')) - model_count = self.model.get_response_count(StatementModel(text='Hello')) - - self.assertEqual(object_count, 2) - self.assertEqual(model_count, 2) - def test_serialize(self): object_data = self.object.serialize() model_data = self.model.serialize() @@ -83,25 +51,3 @@ def test_serialize(self): def test_response_statement_cache(self): self.assertTrue(hasattr(self.object, 'response_statement_cache')) self.assertTrue(hasattr(self.model, 'response_statement_cache')) - - -class ResponseIntegrationTestCase(TestCase): - - """ - Test case to make sure that the Django Response model - and ChatterBot Response object have a common interface. - """ - - def setUp(self): - super(ResponseIntegrationTestCase, self).setUp() - date_created = timezone.now() - statement_object = StatementObject(text='_', created_at=date_created) - statement_model = StatementModel.objects.create(text='_', created_at=date_created) - self.object = ResponseObject(statement_object.text) - self.model = ResponseModel(statement=statement_model, response=statement_model) - - def test_serialize(self): - object_data = self.object.serialize() - model_data = self.model.serialize() - - self.assertEqual(object_data, model_data) diff --git a/tests_django/test_api.py b/tests_django/test_api.py index 10a2c187b..6f15d0af7 100644 --- a/tests_django/test_api.py +++ b/tests_django/test_api.py @@ -27,7 +27,7 @@ def test_post(self): self.assertIn('text', str(response.content)) self.assertIn('in_response_to', str(response.content)) - def test_response_is_in_json(self): + def test_response_is_json(self): """ Test Response is in JSON """ diff --git a/tests_django/test_api_view.py b/tests_django/test_api_view.py index fc8741f14..27ed94798 100644 --- a/tests_django/test_api_view.py +++ b/tests_django/test_api_view.py @@ -10,18 +10,14 @@ def setUp(self): super(ApiIntegrationTestCase, self).setUp() self.api_url = reverse('chatterbot') - # Clear the response queue before tests - ChatterBotView.chatterbot.conversation_sessions.get( - ChatterBotView.chatterbot.default_session.id_string - ).conversation.flush() + # Clear the database before tests + ChatterBotView.chatterbot.storage.drop() def tearDown(self): super(ApiIntegrationTestCase, self).tearDown() - # Clear the response queue after tests - ChatterBotView.chatterbot.conversation_sessions.get( - ChatterBotView.chatterbot.default_session.id_string - ).conversation.flush() + # Clear the database after tests + ChatterBotView.chatterbot.storage.drop() def _get_json(self, response): from django.utils.encoding import force_text @@ -35,18 +31,24 @@ def test_get_conversation_empty(self): self.assertEqual(len(data['conversation']), 0) def test_get_conversation(self): - response = self.client.post( + self.client.post( self.api_url, data=json.dumps({'text': 'How are you?'}), content_type='application/json', format='json' ) + self.client.post( + self.api_url, + data=json.dumps({'text': 'I am good'}), + content_type='application/json', + format='json' + ) + response = self.client.get(self.api_url) data = self._get_json(response) self.assertIn('conversation', data) - self.assertEqual(len(data['conversation']), 1) - self.assertEqual(len(data['conversation'][0]), 2) - self.assertIn('text', data['conversation'][0][0]) - self.assertIn('text', data['conversation'][0][1]) + self.assertEqual(len(data['conversation']), 2) + self.assertIn('text', data['conversation'][0]) + self.assertIn('text', data['conversation'][1]) diff --git a/tests_django/test_django_adapter.py b/tests_django/test_django_adapter.py index 496b9592f..f798afb14 100644 --- a/tests_django/test_django_adapter.py +++ b/tests_django/test_django_adapter.py @@ -2,7 +2,6 @@ from django.test import TestCase from chatterbot.storage import DjangoStorageAdapter from chatterbot.ext.django_chatterbot.models import Statement as StatementModel -from chatterbot.ext.django_chatterbot.models import Response as ResponseModel class DjangoAdapterTestCase(TestCase): @@ -89,16 +88,21 @@ def test_get_random_returns_statement(self): self.assertEqual(random_statement.text, statement.text) def test_find_returns_nested_responses(self): - statement = StatementModel.objects.create(text="Do you like this?") - statement.add_response(StatementModel(text="Yes")) - statement.add_response(StatementModel(text="No")) + question = StatementModel.objects.create(text='Do you like this?') - self.adapter.update(statement) + yes = StatementModel(text='Yes') + yes.in_response_to = question + yes.save() - result = self.adapter.find(statement.text) + no = StatementModel(text='No') + no.in_response_to = question + no.save() + + result = self.adapter.find(question.text) + responses = result.responses() - self.assertTrue(result.in_response_to.filter(response__text="Yes").exists()) - self.assertTrue(result.in_response_to.filter(response__text="No").exists()) + self.assertTrue(responses.in_response_to.filter(text='Yes').exists()) + self.assertTrue(responses.in_response_to.filter(text='No').exists()) def test_multiple_responses_added_on_update(self): statement = StatementModel.objects.create(text="You are welcome.") @@ -154,7 +158,7 @@ def test_remove_response(self): StatementModel(text=text) ) self.adapter.remove(statement.text) - results = self.adapter.filter(in_response_to__contains=text) + results = self.adapter.filter(StatementModel, in_response_to__contains=text) self.assertEqual(results.count(), 0) @@ -196,14 +200,14 @@ def setUp(self): def test_filter_text_no_matches(self): self.adapter.update(self.statement1) - results = self.adapter.filter(text="Howdy") + results = self.adapter.filter(StatementModel, text="Howdy") self.assertEqual(len(results), 0) def test_filter_in_response_to_no_matches(self): self.adapter.update(self.statement1) - results = self.adapter.filter(in_response_to="Maybe") + results = self.adapter.filter(StatementModel, in_response_to="Maybe") self.assertEqual(len(results), 0) def test_filter_equal_results(self): @@ -213,7 +217,7 @@ def test_filter_equal_results(self): self.adapter.update(statement1) self.adapter.update(statement2) - results = self.adapter.filter(in_response_to=[]) + results = self.adapter.filter(StatementModel, in_response_to=[]) self.assertEqual(results.count(), 2) self.assertTrue(results.filter(text=statement1.text).exists()) @@ -224,6 +228,7 @@ def test_filter_contains_result(self): self.adapter.update(self.statement2) results = self.adapter.filter( + StatementModel, in_response_to__contains="Why are you counting?" ) self.assertEqual(results.count(), 1) @@ -233,6 +238,7 @@ def test_filter_contains_no_result(self): self.adapter.update(self.statement1) results = self.adapter.filter( + StatementModel, in_response_to__contains="How do you do?" ) self.assertEqual(results.count(), 0) @@ -242,6 +248,7 @@ def test_filter_multiple_parameters(self): self.adapter.update(self.statement2) results = self.adapter.filter( + StatementModel, text="Testing...", in_response_to__contains="Why are you counting?" ) @@ -254,6 +261,7 @@ def test_filter_multiple_parameters_no_results(self): self.adapter.update(self.statement2) results = self.adapter.filter( + StatementModel, text="Test", in_response_to__contains="Not an existing response." ) @@ -270,7 +278,7 @@ def test_filter_no_parameters(self): self.adapter.update(statement1) self.adapter.update(statement2) - results = self.adapter.filter() + results = self.adapter.filter(StatementModel) self.assertEqual(len(results), 2) @@ -282,6 +290,7 @@ def test_filter_returns_statement_with_multiple_responses(self): self.adapter.update(statement) response = self.adapter.filter( + StatementModel, in_response_to__contains="Thanks." ) @@ -302,10 +311,10 @@ def test_response_list_in_results(self): self.adapter.update(statement) - found = self.adapter.filter(text=statement.text) + found = self.adapter.filter(StatementModel, text=statement.text) self.assertEqual(len(found[0].in_response_to), 1) - self.assertEqual(type(found[0].in_response_to[0]), ResponseModel) + self.assertEqual(type(found[0].in_response_to[0]), StatementModel) def test_confidence(self): """ @@ -333,7 +342,7 @@ def test_order_by_text(self): statement_a = StatementModel.objects.create(text='A is the first letter of the alphabet.') statement_b = StatementModel.objects.create(text='B is the second letter of the alphabet.') - results = self.adapter.filter(order_by='text') + results = self.adapter.filter(StatementModel, order_by='text') self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) @@ -345,7 +354,7 @@ def test_order_by_created_at(self): statement_a = StatementModel.objects.create(text='A is the first letter of the alphabet.') statement_b = StatementModel.objects.create(text='B is the second letter of the alphabet.') - results = self.adapter.filter(order_by='created_at') + results = self.adapter.filter(StatementModel, order_by='created_at') self.assertEqual(len(results), 2) self.assertEqual(results[0], statement_a) diff --git a/tests_django/test_settings.py b/tests_django/test_settings.py index a078b3a63..31774b94f 100644 --- a/tests_django/test_settings.py +++ b/tests_django/test_settings.py @@ -3,6 +3,8 @@ """ import os +DEBUG = True + BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) SECRET_KEY = 'fake-key' diff --git a/tests_django/test_views.py b/tests_django/test_views.py index 2009a74bb..e385953fa 100644 --- a/tests_django/test_views.py +++ b/tests_django/test_views.py @@ -5,8 +5,8 @@ class MockResponse(object): - def __init__(self, id_string): - self.session = {'chat_session_id': id_string} + def __init__(self, pk): + self.session = {'chat_session_id': pk} class ViewTestCase(TestCase): @@ -30,21 +30,21 @@ def test_validate_invalid_text(self): }) def test_get_chat_session(self): - session = self.view.chatterbot.conversation_sessions.new() - mock_response = MockResponse(session.id_string) + session = self.view.chatterbot.conversation_sessions.create() + mock_response = MockResponse(session.id) get_session = self.view.get_chat_session(mock_response) - self.assertEqual(session.id_string, get_session.id_string) + self.assertEqual(session.id, get_session.id) def test_get_chat_session_invalid(self): - mock_response = MockResponse('--invalid--') + mock_response = MockResponse(0) session = self.view.get_chat_session(mock_response) - self.assertNotEqual(session.id_string, 'test-session-id') + self.assertNotEqual(session.id, 'test-session-id') def test_get_chat_session_no_session(self): mock_response = MockResponse(None) mock_response.session = {} session = self.view.get_chat_session(mock_response) - self.assertNotEqual(session.id_string, 'test-session-id') + self.assertNotEqual(session.id, 'test-session-id')