From 2ae86db869c6bae0a4d5836ad014e00005610109 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sat, 27 Jan 2018 10:08:43 -0500 Subject: [PATCH] Use preprocessors during training --- chatterbot/chatterbot.py | 3 ++ chatterbot/trainers.py | 31 +++++++++++++++---- .../test_training_preprocessors.py | 30 ++++++++++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) create mode 100644 tests/training_tests/test_training_preprocessors.py diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index c12bc7563..2a5049dff 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -160,6 +160,9 @@ def set_trainer(self, training_class, **kwargs): :param \**kwargs: Any parameters that should be passed to the training class. """ + if 'chatbot' not in kwargs: + kwargs['chatbot'] = self + self.trainer = training_class(self.storage, **kwargs) @property diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index e7f22ccf2..1f634d1fb 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -2,7 +2,7 @@ import os import sys from .conversation import Statement, Response -from .utils import print_progress_bar +from . import utils class Trainer(object): @@ -11,13 +11,28 @@ class Trainer(object): """ def __init__(self, storage, **kwargs): + self.chatbot = kwargs.get('chatbot') self.storage = storage self.logger = logging.getLogger(__name__) self.show_training_progress = kwargs.get('show_training_progress', True) + def get_preprocessed_statement(self, input_statement): + """ + Preprocess the input statement. + """ + + # The chatbot is optional to prevent backwards-incompatible changes + if not self.chatbot: + return input_statement + + for preprocessor in self.chatbot.preprocessors: + input_statement = preprocessor(self, input_statement) + + return input_statement + def train(self, *args, **kwargs): """ - This class must be overridden by a class the inherits from 'Trainer'. + This method must be overridden by a child class. """ raise self.TrainerInitializationException() @@ -26,10 +41,14 @@ def get_or_create(self, statement_text): Return a statement if it exists. Create and return the statement if it does not exist. """ - statement = self.storage.find(statement_text) + temp_statement = self.get_preprocessed_statement( + Statement(text=statement_text) + ) + + statement = self.storage.find(temp_statement.text) if not statement: - statement = Statement(statement_text) + statement = Statement(temp_statement.text) return statement @@ -83,7 +102,7 @@ def train(self, conversation): for conversation_count, text in enumerate(conversation): if self.show_training_progress: - print_progress_bar( + utils.print_progress_bar( 'List Trainer', conversation_count + 1, len(conversation) ) @@ -128,7 +147,7 @@ def train(self, *corpus_paths): for conversation_count, conversation in enumerate(corpus): if self.show_training_progress: - print_progress_bar( + utils.print_progress_bar( str(os.path.basename(corpus_files[corpus_count])) + ' Training', conversation_count + 1, len(corpus) diff --git a/tests/training_tests/test_training_preprocessors.py b/tests/training_tests/test_training_preprocessors.py new file mode 100644 index 000000000..55888d1c4 --- /dev/null +++ b/tests/training_tests/test_training_preprocessors.py @@ -0,0 +1,30 @@ +from tests.base_case import ChatBotTestCase +from chatterbot import trainers +from chatterbot import preprocessors + + +class PreprocessorTrainingTests(ChatBotTestCase): + """ + These tests are designed to ensure that preprocessors + will be used to process the input the chat bot is given + during the training process. + """ + + def test_training_cleans_whitespace(self): + """ + Test that the ``clean_whitespace`` preprocessor is used during + the training process. + """ + self.chatbot.preprocessors = [preprocessors.clean_whitespace] + self.chatbot.set_trainer(trainers.ListTrainer) + + self.chatbot.train([ + 'Can I help you with anything?', + 'No, I think I am all set.', + 'Okay, have a nice day.', + 'Thank you, you too.' + ]) + + response = self.chatbot.get_response('Can I help you with anything?') + + self.assertEqual(response.text, 'No, I think I am all set.')