diff --git a/chatterbot/conversation/session.py b/chatterbot/conversation/session.py index b98d7e208..cdb3751ec 100644 --- a/chatterbot/conversation/session.py +++ b/chatterbot/conversation/session.py @@ -16,7 +16,10 @@ def all(self): """ Return all statements in the conversation. """ - return self.storage.filter(conversation__id=self.conversation_id) + return self.storage.filter( + conversation__id=self.conversation_id, + order_by='created_at' + ) def add(self, statement): """ @@ -25,6 +28,12 @@ def add(self, statement): statement.conversation_id = self.conversation_id self.storage.update(statement) + def count(self): + return len(self.all()) + + def exists(self): + return self.count() > 0 + class Session(object): """ @@ -52,7 +61,7 @@ def get_last_response_statement(self): statements = self.statements.all() if statements: # Return the latest output statement (This should be ordering them by date to get the latest) - return statements[-1] + return statements[1] return None diff --git a/chatterbot/ext/django_chatterbot/views.py b/chatterbot/ext/django_chatterbot/views.py index 7f3ab81ba..bb044611c 100644 --- a/chatterbot/ext/django_chatterbot/views.py +++ b/chatterbot/ext/django_chatterbot/views.py @@ -44,13 +44,10 @@ class ChatterBotView(ChatterBotViewMixin, View): """ def _serialize_conversation(self, session): - if session.conversation.empty(): - return [] - conversation = [] - for statement, response in session.conversation: - conversation.append([statement.serialize(), response.serialize()]) + for statement in session.statements: + conversation.append(statement.serialize()) return conversation diff --git a/chatterbot/filters.py b/chatterbot/filters.py index 1b45574d6..ae41fcbea 100644 --- a/chatterbot/filters.py +++ b/chatterbot/filters.py @@ -28,13 +28,13 @@ def filter_selection(self, chatterbot, session_id): session = chatterbot.conversation_sessions.get(session_id) - if session.conversation.empty(): + if not session.statements.exists(): return chatterbot.storage.base_query text_of_recent_responses = [] - for statement, response in session.conversation: - text_of_recent_responses.append(response.text) + for statement in session.statements: + text_of_recent_responses.append(statement.text) query = chatterbot.storage.base_query.statement_text_not_in( text_of_recent_responses diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 30034fb91..af082ec98 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -46,7 +46,7 @@ def test_statement_added_to_recent_response_list(self): self.chatbot.default_session.id_string ) - self.assertIn(statement_text, session.conversation[0]) + self.assertIn(statement_text, session.statements.all()) self.assertEqual(response, statement_text) def test_response_known(self): diff --git a/tests/test_context.py b/tests/test_context.py index 7fa6f411e..5b1bd0fb3 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,4 +1,5 @@ from .base_case import ChatBotTestCase +from chatterbot.conversation import Statement class AdapterTests(ChatBotTestCase): @@ -8,14 +9,15 @@ def test_modify_chatbot(self): When one adapter modifies its chatbot instance, the change should be the same in all other adapters. """ + session = self.chatbot.input.chatbot.conversation_sessions.new() self.chatbot.input.chatbot.conversation_sessions.update( session.id_string, - ('A', 'B', ) - ) + Statement('A'), Statement('B') session = self.chatbot.output.chatbot.conversation_sessions.get( session.id_string ) - self.assertIn(('A', 'B', ), session.conversation) + self.assertEqual(Statement('A'), session.statements.all()[0]) + self.assertEqual(Statement('B'), session.statements.all()[1]) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index fd3db2a36..5dcccf09b 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -1,4 +1,4 @@ -from chatterbot.conversation import Conversation +from chatterbot.conversation import Statement, Conversation from chatterbot.conversation.session import ConversationSessionManager from .base_case import ChatBotTestCase @@ -41,11 +41,13 @@ def test_get_invalid_id_with_deafult(self): def test_update(self): session = self.manager.new() - self.manager.update(session.id_string, ('A', 'B', )) + self.manager.update(session.id_string, (Statement('A'), Statement('B'), )) session_ids = list(self.manager.sessions.keys()) session_id = session_ids[0] self.assertEqual(len(session_ids), 1) - self.assertEqual(len(self.manager.get(session_id).conversation), 1) - self.assertEqual(('A', 'B', ), self.manager.get(session_id).conversation[0]) + self.assertEqual(self.manager.get(session_id).statements.count(), 2) + self.assertEqual(Statement('A'), self.manager.get(session_id).statements.all()[0]) + self.assertEqual(Statement('B'), self.manager.get(session_id).statements.all()[1]) +