diff --git a/chatterbot/ext/django_chatterbot/settings.py b/chatterbot/ext/django_chatterbot/settings.py index 3e7b243c3..9d749afaa 100644 --- a/chatterbot/ext/django_chatterbot/settings.py +++ b/chatterbot/ext/django_chatterbot/settings.py @@ -11,7 +11,6 @@ 'storage_adapter': 'chatterbot.storage.DjangoStorageAdapter', 'input_adapter': 'chatterbot.input.VariableInputTypeAdapter', 'output_adapter': 'chatterbot.output.OutputAdapter', - 'use_django_models': True, 'django_app_name': 'django_chatterbot' } diff --git a/chatterbot/storage/django_storage.py b/chatterbot/storage/django_storage.py index 2b6a42625..32965093c 100644 --- a/chatterbot/storage/django_storage.py +++ b/chatterbot/storage/django_storage.py @@ -13,14 +13,24 @@ def __init__(self, **kwargs): self.adapter_supports_queries = False self.django_app_name = kwargs.get('django_app_name', 'django_chatterbot') - def count(self): + def get_statement_model(self): + from django.apps import apps + return apps.get_model(self.django_app_name, 'Statement') + + def get_response_model(self): + from django.apps import apps + return apps.get_model(self.django_app_name, 'Response') + + def get_conversation_model(self): from django.apps import apps - Statement = apps.get_model(self.django_app_name, 'Statement') + return apps.get_model(self.django_app_name, 'Conversation') + + def count(self): + Statement = self.get_model('statement') return Statement.objects.count() def find(self, statement_text): - from django.apps import apps - Statement = apps.get_model(self.django_app_name, 'Statement') + Statement = self.get_model('statement') try: return Statement.objects.get(text=statement_text) except Statement.DoesNotExist as e: @@ -32,9 +42,8 @@ def filter(self, **kwargs): Returns a list of statements in the database that match the parameters specified. """ - from django.apps import apps - Statement = apps.get_model(self.django_app_name, 'Statement') from django.db.models import Q + Statement = self.get_model('statement') order = kwargs.pop('order_by', None) @@ -80,9 +89,8 @@ def update(self, statement): """ Update the provided statement. """ - from django.apps import apps - Statement = apps.get_model(self.django_app_name, 'Statement') - Response = apps.get_model(self.django_app_name, 'Response') + Statement = self.get_model('statement') + Response = self.get_model('response') response_statement_cache = statement.response_statement_cache @@ -109,8 +117,7 @@ def get_random(self): """ Returns a random statement from the database """ - from django.apps import apps - Statement = apps.get_model(self.django_app_name, 'Statement') + Statement = self.get_model('statement') return Statement.objects.order_by('?').first() def remove(self, statement_text): @@ -119,11 +126,10 @@ def remove(self, statement_text): Removes any responses from statements if the response text matches the input text. """ - from django.apps import apps from django.db.models import Q - Statement = apps.get_model(self.django_app_name, 'Statement') - Response = apps.get_model(self.django_app_name, 'Response') + Statement = self.get_model('statement') + Response = self.get_model('response') statements = Statement.objects.filter(text=statement_text) @@ -139,9 +145,7 @@ def get_latest_response(self, conversation_id): Returns the latest response in a conversation if it exists. Returns None if a matching conversation cannot be found. """ - from django.apps import apps - - Response = apps.get_model(self.django_app_name, 'Response') + Response = self.get_model('response') response = Response.objects.filter( conversations__id=conversation_id @@ -158,8 +162,7 @@ def create_conversation(self): """ Create a new conversation. """ - from django.apps import apps - Conversation = apps.get_model(self.django_app_name, 'Conversation') + Conversation = self.get_model('conversation') conversation = Conversation.objects.create() return conversation.id @@ -167,10 +170,8 @@ def add_to_conversation(self, conversation_id, statement, response): """ Add the statement and response to the conversation. """ - from django.apps import apps - - Statement = apps.get_model(self.django_app_name, 'Statement') - Response = apps.get_model(self.django_app_name, 'Response') + Statement = self.get_model('statement') + Response = self.get_model('response') first_statement = Statement.objects.get(text=statement.text) first_response = Statement.objects.get(text=response.text) @@ -186,11 +187,9 @@ def drop(self): """ Remove all data from the database. """ - from django.apps import apps - - Statement = apps.get_model(self.django_app_name, 'Statement') - Response = apps.get_model(self.django_app_name, 'Response') - Conversation = apps.get_model(self.django_app_name, 'Conversation') + Statement = self.get_model('statement') + Response = self.get_model('response') + Conversation = self.get_model('conversation') Statement.objects.all().delete() Response.objects.all().delete() @@ -203,9 +202,8 @@ def get_response_statements(self): in_response_to field. Otherwise, the logic adapter may find a closest matching statement that does not have a known response. """ - from django.apps import apps - Statement = apps.get_model(self.django_app_name, 'Statement') - Response = apps.get_model(self.django_app_name, 'Response') + Statement = self.get_model('statement') + Response = self.get_model('response') responses = Response.objects.all() diff --git a/chatterbot/storage/jsonfile.py b/chatterbot/storage/jsonfile.py index d41ede719..567db099c 100644 --- a/chatterbot/storage/jsonfile.py +++ b/chatterbot/storage/jsonfile.py @@ -71,7 +71,8 @@ def deserialize_responses(self, response_list): Takes the list of response items and returns the list converted to Response objects. """ - proxy_statement = self.Statement('') + Statement = self.get_model('statement') + proxy_statement = Statement('') for response in response_list: data = response.copy() @@ -88,6 +89,7 @@ def json_to_object(self, statement_data): """ Converts a dictionary-like object to a Statement object. """ + Statement = self.get_model('statement') # Don't modify the referenced object statement_data = statement_data.copy() @@ -100,7 +102,7 @@ def json_to_object(self, statement_data): # Remove the text attribute from the values text = statement_data.pop('text') - return self.Statement(text, **statement_data) + return Statement(text, **statement_data) def _all_kwargs_match_values(self, kwarguments, values): for kwarg in kwarguments: @@ -160,6 +162,7 @@ def update(self, statement): """ Update a statement in the database. """ + Statement = self.get_model('statement') data = statement.serialize() # Remove the text key from the data @@ -170,7 +173,7 @@ def update(self, statement): for response_statement in statement.in_response_to: response = self.find(response_statement.text) if not response: - response = self.Statement(response_statement.text) + response = Statement(response_statement.text) self.update(response) return statement diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index 504c59af0..aaa9e4937 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -117,6 +117,7 @@ def count(self): return self.statements.count() def find(self, statement_text): + Statement = self.get_model('statement') query = self.base_query.statement_text_equals(statement_text) values = self.statements.find_one(query.value()) @@ -131,14 +132,15 @@ def find(self, statement_text): values.get('in_response_to', []) ) - return self.Statement(statement_text, **values) + return Statement(statement_text, **values) def deserialize_responses(self, response_list): """ Takes the list of response items and returns the list converted to Response objects. """ - proxy_statement = self.Statement('') + Statement = self.get_model('statement') + proxy_statement = Statement('') for response in response_list: text = response['text'] @@ -155,6 +157,7 @@ def mongo_to_object(self, statement_data): Return Statement object when given data returned from Mongo DB. """ + Statement = self.get_model('statement') statement_text = statement_data['text'] del statement_data['text'] @@ -162,7 +165,7 @@ def mongo_to_object(self, statement_data): statement_data.get('in_response_to', []) ) - return self.Statement(statement_text, **statement_data) + return Statement(statement_text, **statement_data) def filter(self, **kwargs): """ diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index a9308de19..ecb61d546 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -1,5 +1,4 @@ import logging -import os class StorageAdapter(object): @@ -17,23 +16,28 @@ def __init__(self, base_query=None, *args, **kwargs): self.adapter_supports_queries = True self.base_query = None - @property - def Statement(self): + def get_model(self, model_name): """ - Create a storage-aware statement. + Return the model class for a given model name. """ - if 'DJANGO_SETTINGS_MODULE' in os.environ: - django_project = __import__(os.environ['DJANGO_SETTINGS_MODULE']) - if 'use_django_models' in django_project.settings.CHATTERBOT: - if django_project.settings.CHATTERBOT['use_django_models'] is True: - from django.apps import apps - Statement = apps.get_model(django_project.settings.CHATTERBOT['django_app_name'], 'Statement') - return Statement + # The string must be lowercase + model_name = model_name.lower() + get_model_method = getattr(self, 'get_%s_model' % (model_name, )) + + return get_model_method() + + def get_statement_model(self): + """ + Return the class for the statement model. + """ from chatterbot.conversation.statement import Statement + + # Create a storage-aware statement statement = Statement statement.storage = self + return statement def generate_base_query(self, chatterbot, session_id): diff --git a/docs/django/settings.rst b/docs/django/settings.rst index e5098c4e9..f836018bf 100644 --- a/docs/django/settings.rst +++ b/docs/django/settings.rst @@ -24,6 +24,4 @@ Any setting that gets set in the CHATTERBOT dictionary will be passed to the cha Additional Django settings ========================== -- :code:`use_django_models` [default: True] Use the Django models for storing learned conversation data. - If set to False, ChatterBot's non-Django objects will be used. -- :code:`django_app_name` [default: 'django_chatterbot'] The Django app name to look up the models from. \ No newline at end of file +- ``django_app_name`` [default: 'django_chatterbot'] The Django app name to look up the models from. \ No newline at end of file