diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index b4f1a18af..efb7a546c 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -2,7 +2,7 @@ from .adapters.logic import LogicAdapter, MultiLogicAdapter from .adapters.input import InputAdapter from .adapters.output import OutputAdapter -from .conversation import Statement +from .conversation import Statement, Response from .utils.queues import ResponseQueue from .utils.module_loading import import_module @@ -143,7 +143,9 @@ def get_response(self, input_item): previous_statement = self.get_last_response_statement() if previous_statement: - input_statement.add_response(previous_statement) + input_statement.add_response( + Response(previous_statement.text) + ) # Update the database after selecting a response self.storage.update(input_statement) diff --git a/chatterbot/conversation/__init__.py b/chatterbot/conversation/__init__.py index cb328a46f..71af608cb 100644 --- a/chatterbot/conversation/__init__.py +++ b/chatterbot/conversation/__init__.py @@ -1,2 +1,2 @@ from .statement import Statement -from .statement import Response +from .response import Response diff --git a/chatterbot/conversation/response.py b/chatterbot/conversation/response.py new file mode 100644 index 000000000..0c862c193 --- /dev/null +++ b/chatterbot/conversation/response.py @@ -0,0 +1,31 @@ +class Response(object): + """ + A response represents an entity which response to a statement. + """ + + def __init__(self, text, **kwargs): + self.text = text + self.occurrence = kwargs.get("occurrence", 1) + + def __str__(self): + return self.text + + def __repr__(self): + return "" % (self.text) + + def __eq__(self, other): + if not other: + return False + + if isinstance(other, Response): + return self.text == other.text + + return self.text == other + + def serialize(self): + data = {} + + data["text"] = self.text + data["occurrence"] = self.occurrence + + return data diff --git a/chatterbot/conversation/statement.py b/chatterbot/conversation/statement.py index 0fb4d2f2a..7dc83f1ec 100644 --- a/chatterbot/conversation/statement.py +++ b/chatterbot/conversation/statement.py @@ -1,3 +1,6 @@ +from .response import Response + + class Statement(object): """ A statement represents a single spoken entity, sentence or @@ -41,6 +44,14 @@ def add_response(self, response): """ Add the response to the list if it does not already exist. """ + if not isinstance(response, Response): + raise Statement.InvalidTypeException( + 'A {} was recieved when a {} instance was expected'.format( + type(response), + type(Response('')) + ) + ) + updated = False for index in range(0, len(self.in_response_to)): if response.text == self.in_response_to[index].text: @@ -86,35 +97,10 @@ def serialize(self): return data + class InvalidTypeException(Exception): -class Response(object): - """ - A response represents an entity which response to a statement. - """ - - def __init__(self, text, **kwargs): - self.text = text - self.occurrence = kwargs.get("occurrence", 1) - - def __str__(self): - return self.text + def __init__(self, value='Recieved an unexpected value type.'): + self.value = value - def __repr__(self): - return "" % (self.text) - - def __eq__(self, other): - if not other: - return False - - if isinstance(other, Response): - return self.text == other.text - - return self.text == other - - def serialize(self): - data = {} - - data["text"] = self.text - data["occurrence"] = self.occurrence - - return data + def __str__(self): + return repr(self.value) diff --git a/chatterbot/training/trainers.py b/chatterbot/training/trainers.py index 01287d244..372206713 100644 --- a/chatterbot/training/trainers.py +++ b/chatterbot/training/trainers.py @@ -1,4 +1,4 @@ -from chatterbot.conversation import Statement +from chatterbot.conversation import Statement, Response from chatterbot.corpus import Corpus @@ -30,7 +30,9 @@ def train(self, conversation): previous_statement = statement_history[-1] if previous_statement: - statement.add_response(previous_statement) + statement.add_response( + Response(previous_statement.text) + ) statement_history.append(statement) self.storage.update(statement) diff --git a/tests/conversation_tests/test_statements.py b/tests/conversation_tests/test_statements.py index 8ad22fb9a..9bf94dafc 100644 --- a/tests/conversation_tests/test_statements.py +++ b/tests/conversation_tests/test_statements.py @@ -68,3 +68,7 @@ def test_occurrence_count_incremented(self): self.assertEqual(len(self.statement.in_response_to), 1) self.assertEqual(self.statement.in_response_to[0].occurrence, 2) + + def test_add_non_response(self): + with self.assertRaises(Statement.InvalidTypeException): + self.statement.add_response(Statement("Blah")) diff --git a/tests/storage_adapter_tests/test_jsondb_adapter.py b/tests/storage_adapter_tests/test_jsondb_adapter.py index 4111efd18..a7d9e0a3e 100644 --- a/tests/storage_adapter_tests/test_jsondb_adapter.py +++ b/tests/storage_adapter_tests/test_jsondb_adapter.py @@ -80,7 +80,7 @@ def test_update_modifies_existing_statement(self): # Update the statement value statement.add_response( - Statement("New response") + Response("New response") ) self.adapter.update(statement) @@ -365,7 +365,7 @@ def test_update_does_not_modify_existing_statement(self): self.adapter.read_only = True statement.add_response( - Statement("New response") + Response("New response") ) self.adapter.update(statement)