Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow chat bot to track individual conversations #513

Closed
wants to merge 13 commits into from
Closed
53 changes: 35 additions & 18 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ChatBot(object):
"""

def __init__(self, name, **kwargs):
from .conversation.session import ConversationSessionManager
from .conversation.session import ConversationManager
from .logic import MultiLogicAdapter

self.name = name
Expand Down Expand Up @@ -71,8 +71,8 @@ def __init__(self, name, **kwargs):
self.trainer = TrainerClass(self.storage, **kwargs)
self.training_data = kwargs.get('training_data')

self.conversation_sessions = ConversationSessionManager()
self.default_session = self.conversation_sessions.new()
self.conversation_sessions = ConversationManager(self.storage)
self.default_session = None

self.logger = kwargs.get('logger', logging.getLogger(__name__))

Expand Down Expand Up @@ -102,27 +102,33 @@ def get_response(self, input_item, session_id=None):
:returns: A response to the input.
:rtype: Statement
"""
if not session_id:
session_id = str(self.default_session.uuid)

input_statement = self.input.process_input_statement(input_item)

# Preprocess the input statement
for preprocessor in self.preprocessors:
input_statement = preprocessor(self, input_statement)

statement, response = self.generate_response(input_statement, session_id)
if session_id:
session = self.conversation_sessions.get(session_id)

if not session:
session = self.get_or_create_default_conversation()
else:
session = self.get_or_create_default_conversation()

statement, response = self.generate_response(input_statement, session.id)

# Learn that the user's input was a valid response to the chat bot's previous output
previous_statement = self.conversation_sessions.get(
session_id
).conversation.get_last_response_statement()
self.learn_response(statement, previous_statement)
previous_statement = session.get_last_response_statement()

self.learn_response(statement, previous_statement, session)

self.conversation_sessions.update(session_id, (statement, response, ))
if not self.read_only:
response.save()
session.statements.add(response)

# Process the response output with the output adapter
return self.output.process_response(response, session_id)
return self.output.process_response(response, session.id)

def generate_response(self, input_statement, session_id):
"""
Expand All @@ -135,16 +141,15 @@ def generate_response(self, input_statement, session_id):

return input_statement, response

def learn_response(self, statement, previous_statement):
def learn_response(self, statement, previous_statement, session=None):
"""
Learn that the statement provided is a valid response.
"""
from .conversation import Response
if not session:
session = self.get_or_create_default_conversation()

if previous_statement:
statement.add_response(
Response(previous_statement.text)
)
statement.in_response_to = previous_statement
self.logger.info('Adding "{}" as a response to "{}"'.format(
statement.text,
previous_statement.text
Expand All @@ -153,6 +158,7 @@ def learn_response(self, statement, previous_statement):
# Save the statement after selecting a response
if not self.read_only:
self.storage.update(statement)
session.statements.add(statement)

def set_trainer(self, training_class, **kwargs):
"""
Expand All @@ -165,6 +171,17 @@ def set_trainer(self, training_class, **kwargs):
"""
self.trainer = training_class(self.storage, **kwargs)

def get_or_create_default_conversation(self):
"""
Get the default conversation session if it exists.
Otherwise create a new conversation.
This is a lazy function designed to only create the conversation if
a statement exists for it.
"""
if not self.default_session:
self.default_session = self.storage.Conversation.objects.create()
return self.default_session

@property
def train(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion chatterbot/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .statement import Statement
from .response import Response
from .session import Conversation
34 changes: 0 additions & 34 deletions chatterbot/conversation/response.py

This file was deleted.

110 changes: 86 additions & 24 deletions chatterbot/conversation/session.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,112 @@
import uuid
from chatterbot.queues import ResponseQueue


class Session(object):
class ConversationModelMixin(object):

collection_name = 'conversations'

pk_field = 'id'
fields = ('id', )

def get_last_response_statement(self):
"""
Return the last statement that was received.
"""
if self.statements.exists():
# Return the output statement
return self.statements.latest('id')
return None

def get_last_input_statement(self):
"""
Return the last response that was given.
"""
if self.statements.count() > 1:
# Return the input statement
return self.statements.all()[-2]
return None

def serialize(self):
statements = []

for statement in self.statements.all():
statements.append({'text': statement.text})

return {
'id': self.id,
'statements': statements
}


class StatementRelatedManager(object):

def __init__(self, conversation, statements):
self.statements = statements
self.conversation = conversation

def exists(self):
return len(self.statements) > 0

def count(self):
return len(self.statements)

def first(self):
return self.statements[0]

def latest(self, *args):
return self.statements[-1]

def all(self):
return self.statements

def add(self, statement):
self.statements.append(statement)
self.conversation.save()



class Conversation(ConversationModelMixin):
"""
A single chat session.
"""

def __init__(self):
objects = None

def __init__(self, **kwargs):
# A unique identifier for the chat session
self.uuid = uuid.uuid1()
self.id_string = str(self.uuid)
self.id = str(self.uuid)
self.id = kwargs.get('id', str(self.uuid))

# The last 10 statement inputs and outputs
self.conversation = ResponseQueue(maxsize=10)
statements = kwargs.get('statements', [])
self.statements = StatementRelatedManager(self, statements)

def save(self):
self.objects.storage.update(self)

class ConversationSessionManager(object):

class ConversationManager(object):
"""
Object to hold and manage multiple chat sessions.
"""

def __init__(self):
self.sessions = {}
def __init__(self, storage):
self.storage = storage

def new(self):
def create(self):
"""
Add a new chat session.
"""
session = Session()

self.sessions[session.id_string] = session

return session
conversation = self.storage.Conversation()
conversation.save()
return conversation

def get(self, session_id, default=None):
"""
Return a session given a unique identifier.
"""
return self.sessions.get(str(session_id), default)
results = self.storage.filter(self.storage.Conversation, id=session_id)
if results:
return results[0]
else:
return default

def update(self, session_id, conversance):
"""
Add a conversance to a given session if the session exists.
"""
session_id = str(session_id)
if session_id in self.sessions:
self.sessions[session_id].conversation.append(conversance)
Loading