diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py
index c12bc7563..2a5049dff 100644
--- a/chatterbot/chatterbot.py
+++ b/chatterbot/chatterbot.py
@@ -160,6 +160,9 @@ def set_trainer(self, training_class, **kwargs):
 
         :param \**kwargs: Any parameters that should be passed to the training class.
         """
+        if 'chatbot' not in kwargs:
+            kwargs['chatbot'] = self
+
         self.trainer = training_class(self.storage, **kwargs)
 
     @property
diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py
index 0a4b59a1f..beec74b58 100644
--- a/chatterbot/trainers.py
+++ b/chatterbot/trainers.py
@@ -2,7 +2,7 @@
 import os
 import sys
 from .conversation import Statement, Response
-from .utils import print_progress_bar
+from . import utils
 
 
 class Trainer(object):
@@ -11,12 +11,35 @@ class Trainer(object):
     """
 
     def __init__(self, storage, **kwargs):
+        self.chatbot = kwargs.get('chatbot')
         self.storage = storage
         self.logger = logging.getLogger(__name__)
 
+    def get_preprocessed_statement(self, input_statement):
+        """
+        Preprocess the input statement.
+
+        :param input_statement: The statement to run through the
+            chat bot's preprocessors.
+        :returns: The statement after every preprocessor has been
+            applied, or the statement unchanged if this trainer has
+            no chat bot.
+        """
+
+        # The chatbot is optional to prevent backwards-incompatible changes
+        if not self.chatbot:
+            return input_statement
+
+        for preprocessor in self.chatbot.preprocessors:
+            # Preprocessors are registered on the chat bot, so hand them
+            # the chat bot rather than this trainer instance.
+            input_statement = preprocessor(self.chatbot, input_statement)
+
+        return input_statement
+
     def train(self, *args, **kwargs):
         """
-        This class must be overridden by a class the inherits from 'Trainer'.
+        This method must be overridden by a child class.
         """
         raise self.TrainerInitializationException()
 
@@ -25,10 +48,14 @@ def get_or_create(self, statement_text):
         Return a statement if it exists.
         Create and return the statement if it does not exist.
         """
-        statement = self.storage.find(statement_text)
+        temp_statement = self.get_preprocessed_statement(
+            Statement(text=statement_text)
+        )
+
+        statement = self.storage.find(temp_statement.text)
 
         if not statement:
-            statement = Statement(statement_text)
+            statement = Statement(temp_statement.text)
 
         return statement
 
@@ -81,7 +108,7 @@ def train(self, conversation):
         previous_statement_text = None
 
         for conversation_count, text in enumerate(conversation):
-            print_progress_bar("List Trainer", conversation_count + 1, len(conversation))
+            utils.print_progress_bar("List Trainer", conversation_count + 1, len(conversation))
 
             statement = self.get_or_create(text)
 
@@ -121,7 +148,7 @@ def train(self, *corpus_paths):
             corpus_files = self.corpus.list_corpus_files(corpus_path)
             for corpus_count, corpus in enumerate(corpora):
                 for conversation_count, conversation in enumerate(corpus):
-                    print_progress_bar(
+                    utils.print_progress_bar(
                         str(os.path.basename(corpus_files[corpus_count])) + " Training",
                         conversation_count + 1,
                         len(corpus)
diff --git a/tests/training_tests/test_training_preprocessors.py b/tests/training_tests/test_training_preprocessors.py
new file mode 100644
index 000000000..55888d1c4
--- /dev/null
+++ b/tests/training_tests/test_training_preprocessors.py
@@ -0,0 +1,30 @@
+from tests.base_case import ChatBotTestCase
+from chatterbot import trainers
+from chatterbot import preprocessors
+
+
+class PreprocessorTrainingTests(ChatBotTestCase):
+    """
+    These tests are designed to ensure that preprocessors
+    will be used to process the input the chat bot is given
+    during the training process.
+    """
+
+    def test_training_cleans_whitespace(self):
+        """
+        Test that the ``clean_whitespace`` preprocessor is used during
+        the training process.
+        """
+        self.chatbot.preprocessors = [preprocessors.clean_whitespace]
+        self.chatbot.set_trainer(trainers.ListTrainer)
+
+        self.chatbot.train([
+            'Can I help you with anything?',
+            'No, I think I am all set.',
+            'Okay, have a nice day.',
+            'Thank you, you too.'
+        ])
+
+        response = self.chatbot.get_response('Can I help you with anything?')
+
+        self.assertEqual(response.text, 'No, I think I am all set.')