diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 14d0da9eb..a083cafed 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -3,6 +3,7 @@ from .adapters.input import InputAdapter from .adapters.output import OutputAdapter from .conversation import Statement +from .utils.queues import ResponseQueue from .utils.module_loading import import_module @@ -31,7 +32,8 @@ def __init__(self, name, **kwargs): "io_adapter_pairs" ) - self.recent_statements = [] + # The last 10 statement inputs and outputs + self.recent_statements = ResponseQueue(maxsize=10) # The storage adapter must be an instance of StorageAdapter self.validate_adapter_class(storage_adapter, StorageAdapter) @@ -98,8 +100,10 @@ def get_last_statement(self): """ Return the last statement that was received. """ - if self.recent_statements: - return self.recent_statements[-1] + previous_interaction = self.recent_statements[-1] + if previous_interaction: + input_statement, output_statement = previous_interaction + return output_statement return None def get_response(self, input_item): @@ -124,7 +128,9 @@ def get_response(self, input_item): # Update the database after selecting a response self.storage.update(input_statement) - self.recent_statements.append(response) + self.recent_statements.append( + (input_statement, response, ) + ) # Process the response output with the output adapter return self.output.process_response(response) diff --git a/chatterbot/utils/queues.py b/chatterbot/utils/queues.py new file mode 100644 index 000000000..e15fc6b96 --- /dev/null +++ b/chatterbot/utils/queues.py @@ -0,0 +1,30 @@ +class ResponseQueue(object): + """ + This is a data structure like a queue. + Only a fixed number of items can be added. + Once the maximum is reached, when a new item + is added the oldest item in the queue will + be removed. + """ + + def __init__(self, maxsize=10): + self.maxsize = maxsize + self.queue = [] + + def append(self, item): + if len(self.queue) == self.maxsize: + # Remove an element from the top of the list + self.queue.pop(0) + + self.queue.append(item) + + def __getitem__(self, index): + if self.queue: + return self.queue[index] + return None + + def __contains__(self, item): + """ + Check if an element is in this queue. + """ + return item in self.queue diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 4abecd5d9..7a82bbab1 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -10,17 +10,14 @@ def test_get_last_statement(self): returns the last statement that was issued. """ self.chatbot.recent_statements.append( - Statement("Test statement 1") + (Statement("Test statement 1"), Statement("Test response 1"), ) ) self.chatbot.recent_statements.append( - Statement("Test statement 2") - ) - self.chatbot.recent_statements.append( - Statement("Test statement 3") + (Statement("Test statement 2"), Statement("Test response 2"), ) ) last_statement = self.chatbot.get_last_statement() - self.assertEqual(last_statement.text, "Test statement 3") + self.assertEqual(last_statement.text, "Test response 2") class ChatterBotResponseTests(ChatBotTestCase): @@ -58,12 +55,12 @@ def test_statement_saved_empty_database(self): def test_statement_added_to_recent_response_list(self): """ - A new input statement should be added to the recent response list. + An input statement should be added to the recent response list. """ statement_text = "Wow!" response = self.chatbot.get_response(statement_text) - self.assertIn(statement_text, self.chatbot.recent_statements) + self.assertIn(statement_text, self.chatbot.recent_statements[0]) self.assertEqual(response, statement_text) def test_response_known(self): diff --git a/tests/test_queues.py b/tests/test_queues.py new file mode 100644 index 000000000..492057831 --- /dev/null +++ b/tests/test_queues.py @@ -0,0 +1,25 @@ +from unittest import TestCase +from chatterbot.utils.queues import ResponseQueue + + +class ResponseQueueTests(TestCase): + + def setUp(self): + self.queue = ResponseQueue(maxsize=2) + + def test_append(self): + self.queue.append(0) + self.assertIn(0, self.queue) + + def test_contains(self): + self.queue.queue.append(0) + self.assertIn(0, self.queue) + + def test_maxsize(self): + self.queue.append(0) + self.queue.append(1) + self.queue.append(2) + + self.assertNotIn(0, self.queue) + self.assertIn(1, self.queue) + self.assertIn(2, self.queue)