Skip to content

Commit

Permalink
Create conversation object
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Dec 23, 2016
1 parent 50a91ab commit 4da193c
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 13 deletions.
2 changes: 1 addition & 1 deletion chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, name, **kwargs):
self.trainer = TrainerClass(self.storage, **kwargs)
self.training_data = kwargs.get('training_data')

self.conversation_sessions = ConversationSessionManager()
self.conversation_sessions = ConversationSessionManager(self.storage)
self.default_session = self.conversation_sessions.new()

self.logger = kwargs.get('logger', logging.getLogger(__name__))
Expand Down
3 changes: 3 additions & 0 deletions chatterbot/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from .statement import Statement
from .response import Response
from .session import Session

Conversation = Session
41 changes: 38 additions & 3 deletions chatterbot/conversation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,64 @@
from chatterbot.queues import ResponseQueue


class StatementManager(object):
"""
Provides methods for adding and retrieving statements
for this conversation in the database.
"""

def __init__(self, storage, conversation_id):
self.storage = storage
self.conversation_id = conversation_id

def all(self):
"""
Return all statements in the conversation.
"""
return self.storage.filter(conversation__id=self.conversation_id)

def add(self, statement):
"""
Add a statement to the conversation.
"""
statement.conversation_id = self.conversation_id
self.storage.update(statement)

class Session(object):
"""
A single chat session.
TODO: Rename to Conversation
"""

def __init__(self):
def __init__(self, storage):
super(Session, self).__init__()

self.storage = storage

# A unique identifier for the chat session
self.uuid = uuid.uuid1()
self.id_string = str(self.uuid)
self.id = str(self.uuid)

# The last 10 statement inputs and outputs
self.conversation = ResponseQueue(maxsize=10)
self.statements = StatementManager(self.storage, self.id)


class ConversationSessionManager(object):
"""
Object to hold and manage multiple chat sessions.
"""

def __init__(self):
def __init__(self, storage):
self.storage = storage
self.sessions = {}

def new(self):
"""
Add a new chat session.
"""
session = Session()
session = self.storage.create_conversation()

self.sessions[session.id_string] = session

Expand All @@ -47,6 +78,8 @@ def update(self, session_id, conversance):
session_id = str(session_id)
if session_id in self.sessions:
self.sessions[session_id].conversation.append(conversance)
for statement in conversance:
self.sessions[session_id].statements.add(statement)

def get_default(self):
"""
Expand All @@ -65,3 +98,5 @@ def update_default(self, conversance):
if self.sessions:
session_id = list(self.sessions.keys())[0]
self.sessions[session_id].conversation.append(conversance)
for statement in conversance:
self.sessions[session_id].statements.add(statement)
4 changes: 4 additions & 0 deletions chatterbot/conversation/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Statement(object):

def __init__(self, text, **kwargs):
self.text = text
self.conversation_id = kwargs.pop('conversation_id', None)
self.in_response_to = kwargs.pop('in_response_to', [])
self.extra_data = kwargs.pop('extra_data', {})

Expand Down Expand Up @@ -124,6 +125,9 @@ def serialize(self):
data['in_response_to'] = []
data['extra_data'] = self.extra_data

if self.conversation_id:
data['conversation_id'] = self.conversation_id

for response in self.in_response_to:
data['in_response_to'].append(response.serialize())

Expand Down
20 changes: 19 additions & 1 deletion chatterbot/ext/django_chatterbot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class Statement(models.Model):
"""
A short (<255) character message that is part of a dialog.
A message that is part of a dialog of conversation.
"""

text = models.CharField(
Expand All @@ -13,6 +13,15 @@ class Statement(models.Model):
max_length=255
)

conversation = models.ForeignKey(
'Conversation',
related_name='statements',
blank=True,
null=True
)

time_created = models.DateTimeField(auto_now_add=True)

extra_data = models.CharField(max_length=500)

def __str__(self):
Expand Down Expand Up @@ -138,3 +147,12 @@ def __str__(self):
statement if len(statement) <= 20 else statement[:17] + '...',
response if len(response) <= 40 else response[:37] + '...'
)


class Conversation(models.Model):
"""
A sequence of statements representing a conversation.
"""

def __str__(self):
return str(self.id)
13 changes: 13 additions & 0 deletions chatterbot/storage/storage_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ def generate_base_query(self, chatterbot, session_id):
for filter_instance in chatterbot.filters:
self.base_query = filter_instance.filter_selection(chatterbot, session_id)

def create_conversation(self):
"""
Returns a new storage-aware conversation instance.
"""
import os

if 'DJANGO_SETTINGS_MODULE' in os.environ:
from chatterbot.ext.django_chatterbot.models import Conversation
return Conversation.objects.create()
else:
from chatterbot.conversation import Conversation
return Conversation(self)

def count(self):
"""
Return the number of entries in the database.
Expand Down
17 changes: 9 additions & 8 deletions tests/test_sessions.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
from unittest import TestCase
from chatterbot.conversation.session import Session, ConversationSessionManager
from chatterbot.conversation import Conversation
from chatterbot.conversation.session import ConversationSessionManager
from .base_case import ChatBotTestCase


class SessionTestCase(TestCase):
class SessionTestCase(ChatBotTestCase):

def test_id_string(self):
session = Session()
self.assertEqual(str(session.uuid), session.id_string)
conversation = Conversation(self.chatbot.storage)
self.assertEqual(str(conversation.uuid), conversation.id_string)


class ConversationSessionManagerTestCase(TestCase):
class ConversationSessionManagerTestCase(ChatBotTestCase):

def setUp(self):
super(ConversationSessionManagerTestCase, self).setUp()
self.manager = ConversationSessionManager()
self.manager = ConversationSessionManager(self.chatbot.storage)

def test_new(self):
session = self.manager.new()

self.assertTrue(isinstance(session, Session))
self.assertTrue(isinstance(session, Conversation))
self.assertIn(session.id_string, self.manager.sessions)
self.assertEqual(session, self.manager.sessions[session.id_string])

Expand Down

0 comments on commit 4da193c

Please sign in to comment.