Skip to content

Commit

Permalink
Rename session variables to conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Mar 15, 2017
1 parent 60575c6 commit 04d2449
Show file tree
Hide file tree
Showing 19 changed files with 120 additions and 124 deletions.
46 changes: 23 additions & 23 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __init__(self, name, **kwargs):
self.trainer = TrainerClass(self.storage, **kwargs)
self.training_data = kwargs.get('training_data')

self.conversation_sessions = ConversationManager(self.storage)
self.default_session = None
self.conversations = ConversationManager(self.storage)
self.default_conversation = None

self.logger = kwargs.get('logger', logging.getLogger(__name__))

Expand All @@ -94,7 +94,7 @@ def initialize(self):
nltk_download_corpus('tokenizers/punkt')
nltk_download_corpus('sentiment/vader_lexicon')

def get_response(self, input_item, session_id=None):
def get_response(self, input_item, conversation_id=None):
"""
Return the bot's response based on the input.
Expand All @@ -108,45 +108,45 @@ def get_response(self, input_item, session_id=None):
for preprocessor in self.preprocessors:
input_statement = preprocessor(self, input_statement)

if session_id:
session = self.conversation_sessions.get(session_id)
if conversation_id:
conversation = self.conversations.get(conversation_id)

if not session:
session = self.get_or_create_default_conversation()
if not conversation:
conversation = self.get_or_create_default_conversation()
else:
session = self.get_or_create_default_conversation()
conversation = self.get_or_create_default_conversation()

statement, response = self.generate_response(input_statement, session.id)
statement, response = self.generate_response(input_statement, conversation.id)

# Learn that the user's input was a valid response to the chat bot's previous output
previous_statement = session.get_last_response_statement()
previous_statement = conversation.get_last_response_statement()

self.learn_response(statement, previous_statement, session)
self.learn_response(statement, previous_statement, conversation)

if not self.read_only:
response.save()
session.statements.add(response)
conversation.statements.add(response)

# Process the response output with the output adapter
return self.output.process_response(response, session.id)
return self.output.process_response(response, conversation.id)

def generate_response(self, input_statement, session_id):
def generate_response(self, input_statement, conversation_id):
"""
Return a response based on a given input statement.
"""
self.storage.generate_base_query(self, session_id)
self.storage.generate_base_query(self, conversation_id)

# Select a response to the input statement
response = self.logic.process(input_statement)

return input_statement, response

def learn_response(self, statement, previous_statement, session=None):
def learn_response(self, statement, previous_statement, conversation=None):
"""
Learn that the statement provided is a valid response.
"""
if not session:
session = self.get_or_create_default_conversation()
if not conversation:
conversation = self.get_or_create_default_conversation()

if previous_statement:
statement.add_response(
Expand All @@ -160,7 +160,7 @@ def learn_response(self, statement, previous_statement, session=None):
# Save the statement after selecting a response
if not self.read_only:
self.storage.update(statement)
session.statements.add(statement)
conversation.statements.add(statement)

def set_trainer(self, training_class, **kwargs):
"""
Expand All @@ -175,14 +175,14 @@ def set_trainer(self, training_class, **kwargs):

def get_or_create_default_conversation(self):
"""
Get the default conversation session if it exists.
Get the default conversation if it exists.
Otherwise create a new conversation.
This is a lazy function designed to only create the conversation if
a statement exists for it.
"""
if not self.default_session:
self.default_session = self.storage.Conversation.objects.create()
return self.default_session
if not self.default_conversation:
self.default_conversation = self.storage.Conversation.objects.create()
return self.default_conversation

@property
def train(self):
Expand Down
15 changes: 7 additions & 8 deletions chatterbot/conversation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,17 @@ def all(self):
def add(self, statement):
self.statements.append(statement)
self.conversation.save()



class Conversation(ConversationModelMixin):
"""
A single chat session.
A conversation is an ordered collection of statements.
"""

objects = None

def __init__(self, **kwargs):
# A unique identifier for the chat session
# A unique identifier for the conversation
self.uuid = uuid.uuid1()
self.id = kwargs.get('id', str(self.uuid))

Expand All @@ -86,25 +85,25 @@ def save(self):

class ConversationManager(object):
"""
Object to hold and manage multiple chat sessions.
Object to hold and manage conversation.
"""

def __init__(self, storage):
self.storage = storage

def create(self):
"""
Add a new chat session.
Create a new conversation.
"""
conversation = self.storage.Conversation()
conversation.save()
return conversation

def get(self, session_id, default=None):
def get(self, conversation_id, default=None):
"""
Return a session given a unique identifier.
Return a conversation given a unique identifier.
"""
results = self.storage.filter(self.storage.Conversation, id=session_id)
results = self.storage.filter(self.storage.Conversation, id=conversation_id)
if results:
return results[0]
else:
Expand Down
2 changes: 0 additions & 2 deletions chatterbot/ext/django_chatterbot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ class Statement(StatementModelMixin, models.Model):
phrase that someone can say.
"""

collection_name = 'statements'

text = models.CharField(
blank=False,
null=False,
Expand Down
26 changes: 13 additions & 13 deletions chatterbot/ext/django_chatterbot/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ def validate(self, data):
if 'text' not in data:
raise ValidationError('The attribute "text" is required.')

def get_chat_session(self, request):
def get_conversation(self, request):
"""
Return the current session for the chat if one exists.
Create a new session if one does not exist.
Return the current conversation for the chat if one exists.
Create a new conversation if one does not exist.
"""
chat_session_id = request.session.get('chat_session_id', None)
chat_session = self.chatterbot.conversation_sessions.get(chat_session_id, None)
conversation_id = request.session.get('chat_conversation_id', None)
conversation = self.chatterbot.conversations.get(conversation_id, None)

if not chat_session:
chat_session = self.chatterbot.conversation_sessions.create()
request.session['chat_session_id'] = chat_session.id
if not conversation:
conversation = self.chatterbot.conversations.create()
request.session['chat_conversation_id'] = conversation.id

return chat_session
return conversation


class ChatterBotView(ChatterBotViewMixin, View):
Expand Down Expand Up @@ -64,9 +64,9 @@ def post(self, request, *args, **kwargs):
extra_data = input_data['extra_data']
input_data['extra_data'] = json.dumps(extra_data)

chat_session = self.get_chat_session(request)
conversation = self.get_conversation(request)

response = self.chatterbot.get_response(input_data, chat_session.id)
response = self.chatterbot.get_response(input_data, conversation.id)
response_data = response.serialize()

return JsonResponse(response_data, status=200)
Expand All @@ -75,12 +75,12 @@ def get(self, request, *args, **kwargs):
"""
Return data corresponding to the current conversation.
"""
chat_session = self.get_chat_session(request)
conversation = self.get_conversation(request)

data = {
'detail': 'You should make a POST request to this endpoint.',
'name': self.chatterbot.name,
'conversation': self._serialize_conversation(chat_session)
'conversation': self._serialize_conversation(conversation)
}

# Return a method not allowed response
Expand Down
10 changes: 5 additions & 5 deletions chatterbot/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Filter(object):
filters should be subclassed.
"""

def filter_selection(self, chatterbot, session_id):
def filter_selection(self, chatterbot, conversation_id):
"""
Because this is the base filter class, this method just
returns the storage adapter's base query. Other filters
Expand All @@ -24,19 +24,19 @@ class RepetitiveResponseFilter(Filter):
a chat bot from repeating statements that it has recently said.
"""

def filter_selection(self, chatterbot, session_id):
def filter_selection(self, chatterbot, conversation_id):

session = chatterbot.conversation_sessions.get(session_id)
conversation = chatterbot.conversations.get(conversation_id)

# Check if a conversation of some length exists
if not session.statements.exists():
if not conversation.statements.exists():
return chatterbot.storage.base_query

text_of_recent_responses = []

skip = True

for statement in session.statements.all():
for statement in conversation.statements.all():

# Skip every other statement to only filter out the bot's responses
if skip:
Expand Down
20 changes: 11 additions & 9 deletions chatterbot/input/hipchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class HipChat(InputAdapter):
def __init__(self, **kwargs):
super(HipChat, self).__init__(**kwargs)

self.hipchat_host = kwargs.get("hipchat_host")
self.hipchat_access_token = kwargs.get("hipchat_access_token")
self.hipchat_room = kwargs.get("hipchat_room")
self.session_id = str(self.chatbot.default_session.uuid)
self.hipchat_host = kwargs.get('hipchat_host')
self.hipchat_access_token = kwargs.get('hipchat_access_token')
self.hipchat_room = kwargs.get('hipchat_room')
self.conversation = self.chatbot.default_conversation

authorization_header = "Bearer {}".format(self.hipchat_access_token)
authorization_header = 'Bearer {}'.format(self.hipchat_access_token)

self.headers = {
'Authorization': authorization_header,
Expand Down Expand Up @@ -81,10 +81,12 @@ def process_input(self, statement):
"""
new_message = False

input_statement = self.chatbot.conversation_sessions.get(
self.session_id).conversation.get_last_input_statement()
response_statement = self.chatbot.conversation_sessions.get(
self.session_id).conversation.get_last_response_statement()
input_statement = self.chatbot.conversations.get(
self.conversation.id
).conversation.get_last_input_statement()
response_statement = self.chatbot.conversations.get(
self.conversation.id
).conversation.get_last_response_statement()

if input_statement:
last_message_id = input_statement.extra_data.get(
Expand Down
2 changes: 1 addition & 1 deletion chatterbot/logic/multi_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def process(self, statement):
'Make sure that statement.confidence is being set.'.format(adapter.class_name),
DeprecationWarning
)
output = output[1]
output = output[1]

results.append((output.confidence, output, ))

Expand Down
2 changes: 1 addition & 1 deletion chatterbot/output/gitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def send_message(self, text):
self._validate_status_code(response)
return response.json()

def process_response(self, statement, session_id=None):
def process_response(self, statement, conversation_id=None):
self.send_message(statement.text)
return statement

Expand Down
4 changes: 2 additions & 2 deletions chatterbot/output/hipchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def reply_to_message(self):
"""
raise self.AdapterMethodNotImplementedError()

def process_response(self, statement, session_id=None):
def process_response(self, statement, conversation_id=None):
data = self.send_message(self.hipchat_room, statement.text)

# Update the output statement with the message id
self.chatbot.conversation_sessions.get(session_id).conversation[-1][1].add_extra_data(
self.chatbot.conversations.get(conversation_id).conversation[-1][1].add_extra_data(
'hipchat_message_id', data['id']
)

Expand Down
2 changes: 1 addition & 1 deletion chatterbot/output/mailgun.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def send_message(self, subject, text, from_address, recipients):
'text': text
})

def process_response(self, statement, session_id=None):
def process_response(self, statement, conversation_id=None):
"""
Send the response statement as an email.
"""
Expand Down
2 changes: 1 addition & 1 deletion chatterbot/output/microsoft.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def send_message(self, conversation_id, message):
# Microsoft return 204 on operation succeeded and no content was returned.
return self.get_most_recent_message()

def process_response(self, statement, session_id=None):
def process_response(self, statement, conversation_id=None):
data = self.send_message(self.conversation_id, statement.text)
self.logger.info('processing user response {}'.format(data))
return statement
Expand Down
4 changes: 2 additions & 2 deletions chatterbot/output/output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ class OutputAdapter(Adapter):
functionality, such as delivering a response to an API endpoint.
"""

def process_response(self, statement, session_id=None):
def process_response(self, statement, conversation_id=None):
"""
Override this method in a subclass to implement customized functionality.
:param statement: The statement that the chat bot has produced in response to some input.
:param session_id: The unique id of the current chat session.
:param conversation_id: The unique id of the current conversation.
:returns: The response statement.
"""
Expand Down
2 changes: 1 addition & 1 deletion chatterbot/output/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TerminalAdapter(OutputAdapter):
communicate through the terminal.
"""

def process_response(self, statement, session_id=None):
def process_response(self, statement, conversation_id=None):
"""
Print the response to the user's input.
"""
Expand Down
4 changes: 2 additions & 2 deletions chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def __init__(self, base_query=None, *args, **kwargs):

self.Response = Response

def generate_base_query(self, chatterbot, session_id):
def generate_base_query(self, chatterbot, conversation_id):
"""
Create a base query for the storage adapter.
"""
if self.adapter_supports_queries:
for filter_instance in chatterbot.filters:
self.base_query = filter_instance.filter_selection(chatterbot, session_id)
self.base_query = filter_instance.filter_selection(chatterbot, conversation_id)

def count(self):
"""
Expand Down
Loading

0 comments on commit 04d2449

Please sign in to comment.