From bc5c3a2d966f54a4718c922ae6b37eeb582ba2a6 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Thu, 29 Dec 2016 16:53:04 -0500 Subject: [PATCH] Begin storage overhaul --- chatterbot/conversation/session.py | 3 - chatterbot/conversation/statement.py | 41 +++++----- chatterbot/ext/django_chatterbot/models.py | 95 +++------------------- chatterbot/storage/jsonfile.py | 39 ++++----- chatterbot/storage/mongodb.py | 4 +- 5 files changed, 52 insertions(+), 130 deletions(-) diff --git a/chatterbot/conversation/session.py b/chatterbot/conversation/session.py index cdb3751ec..b395803b4 100644 --- a/chatterbot/conversation/session.py +++ b/chatterbot/conversation/session.py @@ -1,5 +1,4 @@ import uuid -from chatterbot.queues import ResponseQueue class StatementManager(object): @@ -25,7 +24,6 @@ def add(self, statement): """ Add a statement to the conversation. """ - statement.conversation_id = self.conversation_id self.storage.update(statement) def count(self): @@ -60,7 +58,6 @@ def get_last_response_statement(self): """ statements = self.statements.all() if statements: - # Return the latest output statement (This should be ordering them by date to get the latest) return statements[1] return None diff --git a/chatterbot/conversation/statement.py b/chatterbot/conversation/statement.py index 5051c70bc..8999c0b64 100644 --- a/chatterbot/conversation/statement.py +++ b/chatterbot/conversation/statement.py @@ -3,6 +3,26 @@ from datetime import datetime +class StatementSerializer(object): + + def serialize(self, obj): + """ + :returns: A dictionary representation of the statement object. + :rtype: dict + """ + data = {} + + data['text'] = obj.text + data['in_response_to'] = {'text': obj.in_response_to.text} + data['created_at'] = obj.created_at + data['extra_data'] = obj.extra_data + + return data + + def deserialize(self, data): + pass + + class Statement(object): """ A statement represents a single spoken entity, sentence or @@ -11,7 +31,6 @@ class Statement(object): def __init__(self, text, **kwargs): self.text = text - self.conversation_id = kwargs.pop('conversation_id', None) self.in_response_to = kwargs.pop('in_response_to', []) # The date and time that this statement was created at @@ -120,24 +139,8 @@ def get_response_count(self, statement): return 0 def serialize(self): - """ - :returns: A dictionary representation of the statement object. - :rtype: dict - """ - data = {} - - data['text'] = self.text - data['in_response_to'] = [] - data['created_at'] = self.created_at - data['extra_data'] = self.extra_data - - if self.conversation_id: - data['conversation_id'] = self.conversation_id - - for response in self.in_response_to: - data['in_response_to'].append(response.serialize()) - - return data + serializer = StatementSerializer() + return serializer.serialize(self) @property def response_statement_cache(self): diff --git a/chatterbot/ext/django_chatterbot/models.py b/chatterbot/ext/django_chatterbot/models.py index add37506a..95a7d7199 100644 --- a/chatterbot/ext/django_chatterbot/models.py +++ b/chatterbot/ext/django_chatterbot/models.py @@ -9,19 +9,11 @@ class Statement(models.Model): """ text = models.CharField( - unique=True, blank=False, null=False, max_length=255 ) - conversation = models.ForeignKey( - 'Conversation', - related_name='statements', - blank=True, - null=True - ) - created_at = models.DateTimeField( default=timezone.now, help_text='The date and time that this statement was created at.' @@ -29,6 +21,12 @@ class Statement(models.Model): extra_data = models.CharField(max_length=500) + response = models.OneToOneField( + 'Statement', + related_name='in_response_to', + null=True + ) + def __str__(self): if len(self.text.strip()) > 60: return '{}...'.format(self.text[:57]) @@ -42,13 +40,6 @@ def __init__(self, *args, **kwargs): # Responses to be saved if the statement is updated with the storage adapter self.response_statement_cache = [] - @property - def in_response_to(self): - """ - Return the response objects that are for this statement. - """ - return Response.objects.filter(statement=self) - def add_extra_data(self, key, value): """ Add extra data to the extra_data field. @@ -69,39 +60,6 @@ def add_response(self, statement): """ self.response_statement_cache.append(statement) - def remove_response(self, response_text): - """ - Removes a response from the statement's response list based - on the value of the response text. - - :param response_text: The text of the response to be removed. - :type response_text: str - """ - is_deleted = False - response = self.in_response.filter(response__text=response_text) - - if response.exists(): - is_deleted = True - - return is_deleted - - def get_response_count(self, statement): - """ - Find the number of times that the statement has been used - as a response to the current statement. - - :param statement: The statement object to get the count for. - :type statement: chatterbot.conversation.statement.Statement - - :returns: Return the number of times the statement has been used as a response. - :rtype: int - """ - try: - response = self.in_response.get(response__text=statement.text) - return response.occurrence - except Response.DoesNotExist: - return 0 - def serialize(self): """ :returns: A dictionary representation of the statement object. @@ -114,51 +72,24 @@ def serialize(self): self.extra_data = '{}' data['text'] = self.text - data['in_response_to'] = [] + data['in_response_to'] = {'text': self.in_response_to.text} data['created_at'] = self.created_at data['extra_data'] = json.loads(self.extra_data) - for response in self.in_response.all(): - data['in_response_to'].append(response.serialize()) - return data -class Response(models.Model): +class Conversation(models.Model): """ - Connection between a response and the statement that triggered it. - - Comparble to a ManyToMany "through" table, but without the M2M indexing/relations. - The text and number of times the response has occurred are stored. + A sequence of statements representing a conversation. """ - statement = models.ForeignKey( + root = models.OneToOneField( 'Statement', - related_name='in_response' + related_name='conversation', + help_text='The initiating statement in a conversation.' ) - response = models.ForeignKey( - 'Statement', - related_name='responses' - ) - - unique_together = (('statement', 'response'),) - - occurrence = models.PositiveIntegerField(default=1) - - def __str__(self): - statement = self.statement.text - response = self.response.text - return '{} => {}'.format( - statement if len(statement) <= 20 else statement[:17] + '...', - response if len(response) <= 40 else response[:37] + '...' - ) - - -class Conversation(models.Model): - """ - A sequence of statements representing a conversation. - """ - def __str__(self): return str(self.id) + diff --git a/chatterbot/storage/jsonfile.py b/chatterbot/storage/jsonfile.py index 6a7657094..bb65873d9 100644 --- a/chatterbot/storage/jsonfile.py +++ b/chatterbot/storage/jsonfile.py @@ -33,14 +33,13 @@ def __init__(self, **kwargs): database_path = self.kwargs.get('database', 'database.db') self.database = Database(database_path) - self.adapter_supports_queries = False + # Create the statements document as an empty list + self.database['statements'] = [] - def _keys(self): - # The value has to be cast as a list for Python 3 compatibility - return list(self.database[0].keys()) + self.adapter_supports_queries = False def count(self): - return len(self._keys()) + return len(self.database['statements'].keys()) def find(self, statement_text): values = self.database.data(key=statement_text) @@ -135,14 +134,10 @@ def filter(self, **kwargs): order_by = kwargs.pop('order_by', None) - for key in self._keys(): - values = self.database.data(key=key) - - # Add the text attribute to the values - values['text'] = key + for statement in self.database['statements']: - if self._all_kwargs_match_values(kwargs, values): - results.append(self.json_to_object(values)) + if self._all_kwargs_match_values(kwargs, statement): + results.append(self.json_to_object(statement)) if order_by: @@ -160,18 +155,16 @@ def update(self, statement, **kwargs): """ # Do not alter the database unless writing is enabled if not self.read_only: - data = statement.serialize() + statements = self.database['statements'] - # Remove the text key from the data - del data['text'] - self.database.data(key=statement.text, value=data) + statements.append(statement.serialize()) # Make sure that an entry for each response exists - for response_statement in statement.in_response_to: - response = self.find(response_statement.text) - if not response: - response = self.Statement(response_statement.text) - self.update(response) + if statement.in_response_to: + response = self.Statement(statement.in_response_to.text) + statements.append(statement.in_response_to.serialize()) + + self.database.data(key='statements', value=statements) return statement @@ -181,8 +174,8 @@ def get_random(self): if self.count() < 1: raise self.EmptyDatabaseException() - statement = choice(self._keys()) - return self.find(statement) + statement = choice(self.database['statements']) + return statement def drop(self): """ diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index e3af2759d..640d39fdc 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -103,9 +103,6 @@ def __init__(self, **kwargs): # The mongo collection of statement documents self.statements = self.database['statements'] - # Set a requirement for the text attribute to be unique - self.statements.create_index('text', unique=True) - self.base_query = Query() def count(self): @@ -211,6 +208,7 @@ def update(self, statement, **kwargs): from pymongo.errors import BulkWriteError force = kwargs.get('force', False) + # Do not alter the database unless writing is enabled if force or not self.read_only: data = statement.serialize()