Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any database model to be used #1001

Merged
merged 4 commits into from
Sep 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion chatterbot/ext/django_chatterbot/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
'storage_adapter': 'chatterbot.storage.DjangoStorageAdapter',
'input_adapter': 'chatterbot.input.VariableInputTypeAdapter',
'output_adapter': 'chatterbot.output.OutputAdapter',
'use_django_models': True,
'django_app_name': 'django_chatterbot'
}

Expand Down
60 changes: 29 additions & 31 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,24 @@ def __init__(self, **kwargs):
self.adapter_supports_queries = False
self.django_app_name = kwargs.get('django_app_name', 'django_chatterbot')

def count(self):
def get_statement_model(self):
from django.apps import apps
return apps.get_model(self.django_app_name, 'Statement')

def get_response_model(self):
from django.apps import apps
return apps.get_model(self.django_app_name, 'Response')

def get_conversation_model(self):
from django.apps import apps
Statement = apps.get_model(self.django_app_name, 'Statement')
return apps.get_model(self.django_app_name, 'Conversation')

def count(self):
Statement = self.get_model('statement')
return Statement.objects.count()

def find(self, statement_text):
from django.apps import apps
Statement = apps.get_model(self.django_app_name, 'Statement')
Statement = self.get_model('statement')
try:
return Statement.objects.get(text=statement_text)
except Statement.DoesNotExist as e:
Expand All @@ -32,9 +42,8 @@ def filter(self, **kwargs):
Returns a list of statements in the database
that match the parameters specified.
"""
from django.apps import apps
Statement = apps.get_model(self.django_app_name, 'Statement')
from django.db.models import Q
Statement = self.get_model('statement')

order = kwargs.pop('order_by', None)

Expand Down Expand Up @@ -80,9 +89,8 @@ def update(self, statement):
"""
Update the provided statement.
"""
from django.apps import apps
Statement = apps.get_model(self.django_app_name, 'Statement')
Response = apps.get_model(self.django_app_name, 'Response')
Statement = self.get_model('statement')
Response = self.get_model('response')

response_statement_cache = statement.response_statement_cache

Expand All @@ -109,8 +117,7 @@ def get_random(self):
"""
Returns a random statement from the database
"""
from django.apps import apps
Statement = apps.get_model(self.django_app_name, 'Statement')
Statement = self.get_model('statement')
return Statement.objects.order_by('?').first()

def remove(self, statement_text):
Expand All @@ -119,11 +126,10 @@ def remove(self, statement_text):
Removes any responses from statements if the response text matches the
input text.
"""
from django.apps import apps
from django.db.models import Q

Statement = apps.get_model(self.django_app_name, 'Statement')
Response = apps.get_model(self.django_app_name, 'Response')
Statement = self.get_model('statement')
Response = self.get_model('response')

statements = Statement.objects.filter(text=statement_text)

Expand All @@ -139,9 +145,7 @@ def get_latest_response(self, conversation_id):
Returns the latest response in a conversation if it exists.
Returns None if a matching conversation cannot be found.
"""
from django.apps import apps

Response = apps.get_model(self.django_app_name, 'Response')
Response = self.get_model('response')

response = Response.objects.filter(
conversations__id=conversation_id
Expand All @@ -158,19 +162,16 @@ def create_conversation(self):
"""
Create a new conversation.
"""
from django.apps import apps
Conversation = apps.get_model(self.django_app_name, 'Conversation')
Conversation = self.get_model('conversation')
conversation = Conversation.objects.create()
return conversation.id

def add_to_conversation(self, conversation_id, statement, response):
"""
Add the statement and response to the conversation.
"""
from django.apps import apps

Statement = apps.get_model(self.django_app_name, 'Statement')
Response = apps.get_model(self.django_app_name, 'Response')
Statement = self.get_model('statement')
Response = self.get_model('response')

first_statement = Statement.objects.get(text=statement.text)
first_response = Statement.objects.get(text=response.text)
Expand All @@ -186,11 +187,9 @@ def drop(self):
"""
Remove all data from the database.
"""
from django.apps import apps

Statement = apps.get_model(self.django_app_name, 'Statement')
Response = apps.get_model(self.django_app_name, 'Response')
Conversation = apps.get_model(self.django_app_name, 'Conversation')
Statement = self.get_model('statement')
Response = self.get_model('response')
Conversation = self.get_model('conversation')

Statement.objects.all().delete()
Response.objects.all().delete()
Expand All @@ -203,9 +202,8 @@ def get_response_statements(self):
in_response_to field. Otherwise, the logic adapter may find a closest
matching statement that does not have a known response.
"""
from django.apps import apps
Statement = apps.get_model(self.django_app_name, 'Statement')
Response = apps.get_model(self.django_app_name, 'Response')
Statement = self.get_model('statement')
Response = self.get_model('response')

responses = Response.objects.all()

Expand Down
21 changes: 18 additions & 3 deletions chatterbot/storage/jsonfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ def __init__(self, **kwargs):

self.adapter_supports_queries = False

def get_statement_model(self):
"""
Return the class for the statement model.
"""
from chatterbot.conversation.statement import Statement

# Create a storage-aware statement
statement = Statement
statement.storage = self

return statement

def _keys(self):
# The value has to be cast as a list for Python 3 compatibility
return list(self.database[0].keys())
Expand Down Expand Up @@ -71,7 +83,8 @@ def deserialize_responses(self, response_list):
Takes the list of response items and returns
the list converted to Response objects.
"""
proxy_statement = self.Statement('')
Statement = self.get_model('statement')
proxy_statement = Statement('')

for response in response_list:
data = response.copy()
Expand All @@ -88,6 +101,7 @@ def json_to_object(self, statement_data):
"""
Converts a dictionary-like object to a Statement object.
"""
Statement = self.get_model('statement')

# Don't modify the referenced object
statement_data = statement_data.copy()
Expand All @@ -100,7 +114,7 @@ def json_to_object(self, statement_data):
# Remove the text attribute from the values
text = statement_data.pop('text')

return self.Statement(text, **statement_data)
return Statement(text, **statement_data)

def _all_kwargs_match_values(self, kwarguments, values):
for kwarg in kwarguments:
Expand Down Expand Up @@ -160,6 +174,7 @@ def update(self, statement):
"""
Update a statement in the database.
"""
Statement = self.get_model('statement')
data = statement.serialize()

# Remove the text key from the data
Expand All @@ -170,7 +185,7 @@ def update(self, statement):
for response_statement in statement.in_response_to:
response = self.find(response_statement.text)
if not response:
response = self.Statement(response_statement.text)
response = Statement(response_statement.text)
self.update(response)

return statement
Expand Down
35 changes: 31 additions & 4 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from chatterbot.storage import StorageAdapter
from chatterbot.conversation import Response


class Query(object):
Expand Down Expand Up @@ -113,10 +112,35 @@ def __init__(self, **kwargs):

self.base_query = Query()

def get_statement_model(self):
"""
Return the class for the statement model.
"""
from chatterbot.conversation.statement import Statement

# Create a storage-aware statement
statement = Statement
statement.storage = self

return statement

def get_response_model(self):
"""
Return the class for the response model.
"""
from chatterbot.conversation.response import Response

# Create a storage-aware response
response = Response
response.storage = self

return response

def count(self):
return self.statements.count()

def find(self, statement_text):
Statement = self.get_model('statement')
query = self.base_query.statement_text_equals(statement_text)

values = self.statements.find_one(query.value())
Expand All @@ -131,14 +155,16 @@ def find(self, statement_text):
values.get('in_response_to', [])
)

return self.Statement(statement_text, **values)
return Statement(statement_text, **values)

def deserialize_responses(self, response_list):
"""
Takes the list of response items and returns
the list converted to Response objects.
"""
proxy_statement = self.Statement('')
Statement = self.get_model('statement')
Response = self.get_model('response')
proxy_statement = Statement('')

for response in response_list:
text = response['text']
Expand All @@ -155,14 +181,15 @@ def mongo_to_object(self, statement_data):
Return Statement object when given data
returned from Mongo DB.
"""
Statement = self.get_model('statement')
statement_text = statement_data['text']
del statement_data['text']

statement_data['in_response_to'] = self.deserialize_responses(
statement_data.get('in_response_to', [])
)

return self.Statement(statement_text, **statement_data)
return Statement(statement_text, **statement_data)

def filter(self, **kwargs):
"""
Expand Down
43 changes: 34 additions & 9 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from chatterbot.storage import StorageAdapter


Expand Down Expand Up @@ -79,11 +78,32 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
# ChatterBot's internal query builder is not yet supported for this adapter
self.adapter_supports_queries = False

def get_statement_model(self):
"""
Return the statement model.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
return Statement

def get_response_model(self):
"""
Return the response model.
"""
from chatterbot.ext.sqlalchemy_app.models import Response
return Response

def get_conversation_model(self):
"""
Return the conversation model.
"""
from chatterbot.ext.sqlalchemy_app.models import Conversation
return Conversation

def count(self):
"""
Return the number of entries in the database.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
Statement = self.get_model('statement')

session = self.Session()
statement_count = session.query(Statement).count()
Expand All @@ -96,7 +116,7 @@ def __statement_filter(self, session, **kwargs):

rtype: query
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
Statement = self.get_model('statement')

_query = session.query(Statement)
return _query.filter_by(**kwargs)
Expand Down Expand Up @@ -138,7 +158,8 @@ def filter(self, **kwargs):
all listed attributes and in which all values
match for all listed attributes will be returned.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement, Response
Statement = self.get_model('statement')
Response = self.get_model('response')

session = self.Session()

Expand Down Expand Up @@ -199,7 +220,8 @@ def update(self, statement):
Modifies an entry in the database.
Creates an entry if one does not exist.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement, Response
Statement = self.get_model('statement')
Response = self.get_model('response')

if statement:
session = self.Session()
Expand Down Expand Up @@ -240,7 +262,7 @@ def create_conversation(self):
"""
Create a new conversation.
"""
from chatterbot.ext.sqlalchemy_app.models import Conversation
Conversation = self.get_model('conversation')

session = self.Session()
conversation = Conversation()
Expand All @@ -260,7 +282,8 @@ def add_to_conversation(self, conversation_id, statement, response):
"""
Add the statement and response to the conversation.
"""
from chatterbot.ext.sqlalchemy_app.models import Conversation, Statement
Statement = self.get_model('statement')
Conversation = self.get_model('conversation')

session = self.Session()
conversation = session.query(Conversation).get(conversation_id)
Expand Down Expand Up @@ -296,7 +319,7 @@ def get_latest_response(self, conversation_id):
Returns the latest response in a conversation if it exists.
Returns None if a matching conversation cannot be found.
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
Statement = self.get_model('statement')

session = self.Session()
statement = None
Expand All @@ -318,7 +341,9 @@ def get_random(self):
"""
Returns a random statement from the database
"""
from chatterbot.ext.sqlalchemy_app.models import Statement
import random

Statement = self.get_model('statement')

session = self.Session()
count = self.count()
Expand Down
Loading