Skip to content

Commit

Permalink
Use adapter statement objects in trainers
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Mar 15, 2017
1 parent c059ff3 commit 60575c6
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 22 deletions.
15 changes: 9 additions & 6 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def get_response(self, input_item, session_id=None):

if session_id:
session = self.conversation_sessions.get(session_id)

if not session:
session = self.get_or_create_default_conversation()
else:
session = self.get_or_create_default_conversation()

Expand All @@ -118,12 +121,10 @@ def get_response(self, input_item, session_id=None):
# Learn that the user's input was a valid response to the chat bot's previous output
previous_statement = session.get_last_response_statement()

self.learn_response(statement, previous_statement)
self.learn_response(statement, previous_statement, session)

if not self.read_only:
statement.save()
response.save()
session.statements.add(statement)
session.statements.add(response)

# Process the response output with the output adapter
Expand All @@ -140,15 +141,16 @@ def generate_response(self, input_statement, session_id):

return input_statement, response

def learn_response(self, statement, previous_statement):
def learn_response(self, statement, previous_statement, session=None):
"""
Learn that the statement provided is a valid response.
"""
from .conversation import Response
if not session:
session = self.get_or_create_default_conversation()

if previous_statement:
statement.add_response(
Response(previous_statement.text)
self.storage.Response(previous_statement.text)
)
self.logger.info('Adding "{}" as a response to "{}"'.format(
statement.text,
Expand All @@ -158,6 +160,7 @@ def learn_response(self, statement, previous_statement):
# Save the statement after selecting a response
if not self.read_only:
self.storage.update(statement)
session.statements.add(statement)

def set_trainer(self, training_class, **kwargs):
"""
Expand Down
17 changes: 11 additions & 6 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,22 @@ def update(self, obj):

if obj.collection_name == 'statements':

statement, created = Statement.objects.get_or_create(text=obj.text)
statement.extra_data = getattr(obj, 'extra_data', '')
statement.save()
existing_statements = Statement.objects.filter(text=obj.text)

if existing_statements.exists():
statement = existing_statements.first()
statement.extra_data = getattr(obj, 'extra_data', '')
else:
obj.save()
statement = obj

for _response_statement in obj.response_statement_cache:

existing_statements = Statement.objects.filter(
existing_responses = Statement.objects.filter(
text=_response_statement.text
)
if existing_statements.exists():
response_statement = existing_statements.first()
if existing_responses.exists():
response_statement = existing_responses.first()
else:
response_statement = Statement(
text=_response_statement.text
Expand Down
7 changes: 7 additions & 0 deletions chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,24 @@ def __init__(self, base_query=None, *args, **kwargs):
django_project = __import__(os.environ['DJANGO_SETTINGS_MODULE'])
if django_project.settings.CHATTERBOT['use_django_models'] is True:
from chatterbot.ext.django_chatterbot.models import Statement, Conversation
from chatterbot.conversation.response import Response

self.Statement = Statement
self.Conversation = Conversation
self.Response = Response
else:
from chatterbot.conversation.statement import Statement
from chatterbot.conversation.session import Conversation, ConversationManager
from chatterbot.conversation.response import Response

self.Statement = Statement
self.Statement.storage = self

self.Conversation = Conversation
self.Conversation.objects = ConversationManager(self)

self.Response = Response

def generate_base_query(self, chatterbot, session_id):
"""
Create a base query for the storage adapter.
Expand Down
11 changes: 5 additions & 6 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from .conversation import Statement, Response


class Trainer(object):
Expand All @@ -25,7 +24,7 @@ def get_or_create(self, statement_text):
statement = self.storage.find(statement_text)

if not statement:
statement = Statement(statement_text)
statement = self.storage.Statement(text=statement_text)

return statement

Expand Down Expand Up @@ -83,7 +82,7 @@ def train(self, conversation):

if statement_history:
statement.add_response(
Response(statement_history[-1].text)
self.storage.Response(statement_history[-1].text)
)

statement_history.append(statement)
Expand Down Expand Up @@ -192,12 +191,12 @@ def get_statements(self):
self.logger.info(u'Requesting 50 random tweets containing the word {}'.format(random_word))
tweets = self.api.GetSearch(term=random_word, count=50)
for tweet in tweets:
statement = Statement(tweet.text)
statement = self.storage.Statement(text=tweet.text)

if tweet.in_reply_to_status_id:
try:
status = self.api.GetStatus(tweet.in_reply_to_status_id)
statement.add_response(Response(status.text))
statement.add_response(self.storage.Response(status.text))
statements.append(statement)
except TwitterError as error:
self.logger.warning(str(error))
Expand Down Expand Up @@ -347,7 +346,7 @@ def train(self):

if statement_history:
statement.add_response(
Response(statement_history[-1].text)
self.storage.Response(statement_history[-1].text)
)

statement_history.append(statement)
Expand Down
2 changes: 0 additions & 2 deletions examples/django_app/tests/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ def test_get_conversation(self):
response = self.client.get(self.api_url)
data = self._get_json(response)

print data['conversation']

self.assertIn('conversation', data)
self.assertEqual(len(data['conversation']), 1)
self.assertIn('text', data['conversation'][0])
3 changes: 1 addition & 2 deletions tests_django/test_api_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def test_get_conversation(self):
data = self._get_json(response)

self.assertIn('conversation', data)
self.assertEqual(len(data['conversation']), 3)
self.assertEqual(len(data['conversation']), 2)
self.assertIn('text', data['conversation'][0])
self.assertIn('text', data['conversation'][1])
self.assertIn('text', data['conversation'][2])

0 comments on commit 60575c6

Please sign in to comment.