diff --git a/chatterbot/adapters/logic/base_match.py b/chatterbot/adapters/logic/base_match.py index 69947bcee..763e71bfc 100644 --- a/chatterbot/adapters/logic/base_match.py +++ b/chatterbot/adapters/logic/base_match.py @@ -23,20 +23,7 @@ def has_storage_context(self): """ return self.context and self.context.storage - def get_available_statements(self, statement_list=None): - from chatterbot.conversation.utils import get_response_statements - - if statement_list: - statement_list = get_response_statements(statement_list) - - # Check if the list is empty - if not statement_list and self.has_storage_context: - all_statements = self.context.storage.filter() - statement_list = get_response_statements(all_statements) - - return statement_list - - def get(self, input_statement, statement_list=None): + def get(self, input_statement): """ This method should be overridden with one to select a match based on the input statement. diff --git a/chatterbot/adapters/logic/closest_match.py b/chatterbot/adapters/logic/closest_match.py index b034d5332..07d87eaa9 100644 --- a/chatterbot/adapters/logic/closest_match.py +++ b/chatterbot/adapters/logic/closest_match.py @@ -12,12 +12,12 @@ class ClosestMatchAdapter(BaseMatchAdapter): of each statement. """ - def get(self, input_statement, statement_list=None): + def get(self, input_statement): """ Takes a statement string and a list of statement strings. Returns the closest matching statement from the list. """ - statement_list = self.get_available_statements(statement_list) + statement_list = self.context.storage.get_response_statements() if not statement_list: if self.has_storage_context: diff --git a/chatterbot/adapters/logic/closest_meaning.py b/chatterbot/adapters/logic/closest_meaning.py index b6817f9ba..575f588d6 100644 --- a/chatterbot/adapters/logic/closest_meaning.py +++ b/chatterbot/adapters/logic/closest_meaning.py @@ -75,12 +75,12 @@ def get_similarity(self, string1, string2): return total_similarity - def get(self, input_statement, statement_list=None): + def get(self, input_statement): """ Takes a statement string and a list of statement strings. Returns the closest matching statement from the list. """ - statement_list = self.get_available_statements(statement_list) + statement_list = self.context.storage.get_response_statements() if not statement_list: if self.has_storage_context: diff --git a/chatterbot/adapters/storage/mongodb.py b/chatterbot/adapters/storage/mongodb.py index 88db7643e..be01f9cdb 100644 --- a/chatterbot/adapters/storage/mongodb.py +++ b/chatterbot/adapters/storage/mongodb.py @@ -172,6 +172,33 @@ def remove(self, statement_text): self.statements.remove({'text': statement_text}) + def get_response_statements(self): + """ + Return only statements that are in response to another statement. + A statement must exist which lists the closest matching statement in the + in_response_to field. Otherwise, the logic adapter may find a closest + matching statement that does not have a known response. + """ + response_query = self.statements.distinct('in_response_to.text') + statement_query = self.statements.find({ + 'text': { + '$in': response_query + } + }) + + statement_list = list(statement_query) + + statement_objects = [] + + for statement in statement_list: + values = dict(statement) + statement_text = values['text'] + + del(values['text']) + statement_objects.append(Statement(statement_text, **values)) + + return statement_objects + def drop(self): """ Remove the database. diff --git a/chatterbot/adapters/storage/storage_adapter.py b/chatterbot/adapters/storage/storage_adapter.py index 32bf4a249..0c9e53d63 100644 --- a/chatterbot/adapters/storage/storage_adapter.py +++ b/chatterbot/adapters/storage/storage_adapter.py @@ -62,6 +62,32 @@ def drop(self): """ raise self.AdapterMethodNotImplementedError() + def get_response_statements(self): + """ + Return only statements that are in response to another statement. + A statement must exist which lists the closest matching statement in the + in_response_to field. Otherwise, the logic adapter may find a closest + matching statement that does not have a known response. + + This method may be overridden by a child class to provide more a + efficient method to get these results. + """ + statement_list = self.filter() + + responses = set() + to_remove = list() + for statement in statement_list: + for response in statement.in_response_to: + responses.add(response.text) + for statement in statement_list: + if statement.text not in responses: + to_remove.append(statement) + + for statement in to_remove: + statement_list.remove(statement) + + return statement_list + class EmptyDatabaseException(Exception): def __init__(self, message="The database currently contains no entries. At least one entry is expected. You may need to train your chat bot to populate your database."): diff --git a/chatterbot/conversation/utils.py b/chatterbot/conversation/utils.py deleted file mode 100644 index 916b3df1b..000000000 --- a/chatterbot/conversation/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -def get_response_statements(statement_list): - """ - Filter out all statements that are not in response to another statement. - A statement must exist which lists the closest matching statement in the - in_response_to field. Otherwise, the logic adapter may find a closest - matching statement that does not have a known response. - """ - responses = set() - to_remove = list() - for statement in statement_list: - for response in statement.in_response_to: - responses.add(response.text) - for statement in statement_list: - if statement.text not in responses: - to_remove.append(statement) - - for statement in to_remove: - statement_list.remove(statement) - - return statement_list diff --git a/tests/conversation_tests/test_utils.py b/tests/conversation_tests/test_utils.py deleted file mode 100644 index 6f53c1da2..000000000 --- a/tests/conversation_tests/test_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from unittest import TestCase -from chatterbot.conversation.utils import get_response_statements -from chatterbot.conversation import Statement, Response - - -class ConversationUtilsTests(TestCase): - - def test_get_statements_with_known_responses(self): - statement_list = [ - Statement("What... is your quest?"), - Statement("This is a phone."), - Statement("A what?", in_response_to=[Response("This is a phone.")]), - Statement("A phone.", in_response_to=[Response("A what?")]) - ] - - responses = get_response_statements(statement_list) - - self.assertEqual(len(responses), 2) - self.assertIn("This is a phone.", responses) - self.assertIn("A what?", responses) diff --git a/tests/logic_adapter_tests/test_closest_match.py b/tests/logic_adapter_tests/test_closest_match.py index 27a250c42..367259727 100644 --- a/tests/logic_adapter_tests/test_closest_match.py +++ b/tests/logic_adapter_tests/test_closest_match.py @@ -1,19 +1,34 @@ from unittest import TestCase +from mock import MagicMock, Mock from chatterbot.adapters.logic import ClosestMatchAdapter +from chatterbot.adapters.storage import StorageAdapter from chatterbot.conversation import Statement, Response +class MockContext(object): + def __init__(self): + self.storage = StorageAdapter() + + self.storage.get_random = Mock( + side_effect=ClosestMatchAdapter.EmptyDatasetException() + ) + + class ClosestMatchAdapterTests(TestCase): def setUp(self): self.adapter = ClosestMatchAdapter() + # Add a mock storage adapter to the context + self.adapter.set_context(MockContext()) + def test_no_choices(self): - possible_choices = [] + self.adapter.context.storage.filter = MagicMock(return_value=[]) + statement = Statement("What is your quest?") with self.assertRaises(ClosestMatchAdapter.EmptyDatasetException): - self.adapter.get(statement, possible_choices) + self.adapter.get(statement) def test_get_closest_statement(self): """ @@ -29,9 +44,11 @@ def test_get_closest_statement(self): Statement("Yuck, black licorice jelly beans.", in_response_to=[Response("What is the meaning of life?")]), Statement("I hear you are going on a quest?", in_response_to=[Response("Who do you love?")]), ] + self.adapter.context.storage.filter = MagicMock(return_value=possible_choices) + statement = Statement("What is your quest?") - confidence, match = self.adapter.get(statement, possible_choices) + confidence, match = self.adapter.get(statement) self.assertEqual("What... is your quest?", match) @@ -39,12 +56,10 @@ def test_confidence_exact_match(self): possible_choices = [ Statement("What is your quest?", in_response_to=[Response("What is your quest?")]) ] + self.adapter.context.storage.filter = MagicMock(return_value=possible_choices) statement = Statement("What is your quest?") - - confidence, match = self.adapter.get( - statement, possible_choices - ) + confidence, match = self.adapter.get(statement) self.assertEqual(confidence, 1) @@ -52,12 +67,10 @@ def test_confidence_half_match(self): possible_choices = [ Statement("xxyy", in_response_to=[Response("xxyy")]) ] + self.adapter.context.storage.filter = MagicMock(return_value=possible_choices) statement = Statement("wwxx") - - confidence, match = self.adapter.get( - statement, possible_choices - ) + confidence, match = self.adapter.get(statement) self.assertEqual(confidence, 0.5) @@ -65,11 +78,9 @@ def test_confidence_no_match(self): possible_choices = [ Statement("xxx", in_response_to=[Response("xxx")]) ] + self.adapter.context.storage.filter = MagicMock(return_value=possible_choices) statement = Statement("yyy") - - confidence, match = self.adapter.get( - statement, possible_choices - ) + confidence, match = self.adapter.get(statement) self.assertEqual(confidence, 0) diff --git a/tests/logic_adapter_tests/test_closest_meaning.py b/tests/logic_adapter_tests/test_closest_meaning.py index 460bea7ef..27716d62c 100644 --- a/tests/logic_adapter_tests/test_closest_meaning.py +++ b/tests/logic_adapter_tests/test_closest_meaning.py @@ -1,19 +1,33 @@ from unittest import TestCase +from mock import MagicMock, Mock from chatterbot.adapters.logic import ClosestMeaningAdapter +from chatterbot.adapters.storage import StorageAdapter from chatterbot.conversation import Statement, Response +class MockContext(object): + def __init__(self): + self.storage = StorageAdapter() + + self.storage.get_random = Mock( + side_effect=ClosestMeaningAdapter.EmptyDatasetException() + ) + + class ClosestMeaningAdapterTests(TestCase): def setUp(self): self.adapter = ClosestMeaningAdapter() + # Add a mock storage adapter to the context + self.adapter.set_context(MockContext()) + def test_no_choices(self): - possible_choices = [] + self.adapter.context.storage.filter = MagicMock(return_value=[]) statement = Statement("Hello") with self.assertRaises(ClosestMeaningAdapter.EmptyDatasetException): - self.adapter.get(statement, possible_choices) + self.adapter.get(statement) def test_get_closest_statement(self): """ @@ -26,9 +40,10 @@ def test_get_closest_statement(self): Statement("This is a beautiful swamp.", in_response_to=[Response("This is a beautiful swamp.")]), Statement("It smells like swamp.", in_response_to=[Response("It smells like swamp.")]) ] - statement = Statement("This is a lovely swamp.") + self.adapter.context.storage.filter = MagicMock(return_value=possible_choices) - confidence, match = self.adapter.get(statement, possible_choices) + statement = Statement("This is a lovely swamp.") + confidence, match = self.adapter.get(statement) self.assertEqual("This is a lovely bog.", match) diff --git a/tests/logic_adapter_tests/test_logic_adapter.py b/tests/logic_adapter_tests/test_logic_adapter.py index 01f99ed7e..a3c09d5de 100644 --- a/tests/logic_adapter_tests/test_logic_adapter.py +++ b/tests/logic_adapter_tests/test_logic_adapter.py @@ -1,6 +1,7 @@ from unittest import TestCase from chatterbot.adapters.logic import LogicAdapter + class LogicAdapterTestCase(TestCase): """ This test case is for the LogicAdapter base class. diff --git a/tests/storage_adapter_tests/test_jsondb_adapter.py b/tests/storage_adapter_tests/test_jsondb_adapter.py index 254d190c5..4111efd18 100644 --- a/tests/storage_adapter_tests/test_jsondb_adapter.py +++ b/tests/storage_adapter_tests/test_jsondb_adapter.py @@ -184,6 +184,27 @@ def test_remove_response(self): self.assertEqual(results, []) + def test_get_response_statements(self): + """ + Test that we are able to get a list of only statements + that are known to be in response to another statement. + """ + statement_list = [ + Statement("What... is your quest?"), + Statement("This is a phone."), + Statement("A what?", in_response_to=[Response("This is a phone.")]), + Statement("A phone.", in_response_to=[Response("A what?")]) + ] + + for statement in statement_list: + self.adapter.update(statement) + + responses = self.adapter.get_response_statements() + + self.assertEqual(len(responses), 2) + self.assertIn("This is a phone.", responses) + self.assertIn("A what?", responses) + class JsonDatabaseAdapterFilterTestCase(JsonAdapterTestCase): diff --git a/tests/storage_adapter_tests/test_mongo_adapter.py b/tests/storage_adapter_tests/test_mongo_adapter.py index 124ebf027..d7280cbcb 100644 --- a/tests/storage_adapter_tests/test_mongo_adapter.py +++ b/tests/storage_adapter_tests/test_mongo_adapter.py @@ -195,6 +195,27 @@ def test_remove_response(self): self.assertEqual(results, []) + def test_get_response_statements(self): + """ + Test that we are able to get a list of only statements + that are known to be in response to another statement. + """ + statement_list = [ + Statement("What... is your quest?"), + Statement("This is a phone."), + Statement("A what?", in_response_to=[Response("This is a phone.")]), + Statement("A phone.", in_response_to=[Response("A what?")]) + ] + + for statement in statement_list: + self.adapter.update(statement) + + responses = self.adapter.get_response_statements() + + self.assertEqual(len(responses), 2) + self.assertIn("This is a phone.", responses) + self.assertIn("A what?", responses) + class MongoAdapterFilterTestCase(MongoAdapterTestCase): diff --git a/tests/storage_adapter_tests/test_storage_adapter.py b/tests/storage_adapter_tests/test_storage_adapter.py index 1abf17e0b..d3ef49725 100644 --- a/tests/storage_adapter_tests/test_storage_adapter.py +++ b/tests/storage_adapter_tests/test_storage_adapter.py @@ -1,5 +1,6 @@ from unittest import TestCase from chatterbot.adapters.storage import StorageAdapter +from chatterbot.conversation import Statement, Response class StorageAdapterTestCase(TestCase): """ @@ -37,6 +38,10 @@ def test_get_random(self): with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): self.adapter.get_random() + def test_get_response_statements(self): + with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): + self.adapter.get_response_statements() + def test_drop(self): with self.assertRaises(StorageAdapter.AdapterMethodNotImplementedError): self.adapter.drop()