Skip to content

Commit

Permalink
Update tests for variable changes
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Jan 27, 2017
1 parent 7c712a4 commit 72457f9
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 18 deletions.
13 changes: 11 additions & 2 deletions chatterbot/conversation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def all(self):
"""
Return all statements in the conversation.
"""
return self.storage.filter(conversation__id=self.conversation_id)
return self.storage.filter(
conversation__id=self.conversation_id,
order_by='created_at'
)

def add(self, statement):
"""
Expand All @@ -25,6 +28,12 @@ def add(self, statement):
statement.conversation_id = self.conversation_id
self.storage.update(statement)

def count(self):
return len(self.all())

def exists(self):
return self.count() > 0


class Session(object):
"""
Expand Down Expand Up @@ -52,7 +61,7 @@ def get_last_response_statement(self):
statements = self.statements.all()
if statements:
# Return the latest output statement (This should be ordering them by date to get the latest)
return statements[-1]
return statements[1]
return None


Expand Down
7 changes: 2 additions & 5 deletions chatterbot/ext/django_chatterbot/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,10 @@ class ChatterBotView(ChatterBotViewMixin, View):
"""

def _serialize_conversation(self, session):
if session.conversation.empty():
return []

conversation = []

for statement, response in session.conversation:
conversation.append([statement.serialize(), response.serialize()])
for statement in session.statements:
conversation.append(statement.serialize())

return conversation

Expand Down
6 changes: 3 additions & 3 deletions chatterbot/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def filter_selection(self, chatterbot, session_id):

session = chatterbot.conversation_sessions.get(session_id)

if session.conversation.empty():
if not session.statements.exists():
return chatterbot.storage.base_query

text_of_recent_responses = []

for statement, response in session.conversation:
text_of_recent_responses.append(response.text)
for statement in session.statements:
text_of_recent_responses.append(statement.text)

query = chatterbot.storage.base_query.statement_text_not_in(
text_of_recent_responses
Expand Down
2 changes: 1 addition & 1 deletion tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_statement_added_to_recent_response_list(self):
self.chatbot.default_session.id_string
)

self.assertIn(statement_text, session.conversation[0])
self.assertIn(statement_text, session.statements.all())
self.assertEqual(response, statement_text)

def test_response_known(self):
Expand Down
8 changes: 5 additions & 3 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_case import ChatBotTestCase
from chatterbot.conversation import Statement


class AdapterTests(ChatBotTestCase):
Expand All @@ -8,14 +9,15 @@ def test_modify_chatbot(self):
When one adapter modifies its chatbot instance,
the change should be the same in all other adapters.
"""

session = self.chatbot.input.chatbot.conversation_sessions.new()
self.chatbot.input.chatbot.conversation_sessions.update(
session.id_string,
('A', 'B', )
)
Statement('A'), Statement('B')

session = self.chatbot.output.chatbot.conversation_sessions.get(
session.id_string
)

self.assertIn(('A', 'B', ), session.conversation)
self.assertEqual(Statement('A'), session.statements.all()[0])
self.assertEqual(Statement('B'), session.statements.all()[1])
10 changes: 6 additions & 4 deletions tests/test_sessions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from chatterbot.conversation import Conversation
from chatterbot.conversation import Statement, Conversation
from chatterbot.conversation.session import ConversationSessionManager
from .base_case import ChatBotTestCase

Expand Down Expand Up @@ -41,11 +41,13 @@ def test_get_invalid_id_with_deafult(self):

def test_update(self):
session = self.manager.new()
self.manager.update(session.id_string, ('A', 'B', ))
self.manager.update(session.id_string, (Statement('A'), Statement('B'), ))

session_ids = list(self.manager.sessions.keys())
session_id = session_ids[0]

self.assertEqual(len(session_ids), 1)
self.assertEqual(len(self.manager.get(session_id).conversation), 1)
self.assertEqual(('A', 'B', ), self.manager.get(session_id).conversation[0])
self.assertEqual(self.manager.get(session_id).statements.count(), 2)
self.assertEqual(Statement('A'), self.manager.get(session_id).statements.all()[0])
self.assertEqual(Statement('B'), self.manager.get(session_id).statements.all()[1])

0 comments on commit 72457f9

Please sign in to comment.