diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 8316b0a81..0840c8d95 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -2,6 +2,7 @@ import os import sys from .conversation import Statement, Response +from .utils import print_progress_bar class Trainer(object): @@ -79,7 +80,9 @@ def train(self, conversation): """ previous_statement_text = None - for text in conversation: + for conversation_count, text in enumerate(conversation): + print_progress_bar("List Trainer", conversation_count + 1, len(conversation)) + statement = self.get_or_create(text) if previous_statement_text: @@ -115,8 +118,15 @@ def train(self, *corpus_paths): corpora = self.corpus.load_corpus(corpus_path) - for corpus in corpora: - for conversation in corpus: + corpus_files = self.corpus.list_corpus_files(corpus_path) + for corpus_count, corpus in enumerate(corpora): + for conversation_count, conversation in enumerate(corpus): + print_progress_bar( + str(os.path.basename(corpus_files[corpus_count])) + " Training", + conversation_count + 1, + len(corpus) + ) + previous_statement_text = None for text in conversation: diff --git a/chatterbot/utils.py b/chatterbot/utils.py index 6b51ba1f7..3efc0ba34 100644 --- a/chatterbot/utils.py +++ b/chatterbot/utils.py @@ -193,3 +193,32 @@ def generate_strings(total_strings, string_length=20): ) statements.append(text) return statements + + +def print_progress_bar(description, iteration_counter, total_items, progress_bar_length=20): + """ + Print progress bar + :param description: Training description + :type description: str + + :param iteration_counter: Incremental counter + :type iteration_counter: int + + :param total_items: total number items + :type total_items: int + + :param progress_bar_length: Progress bar length + :type progress_bar_length: int + + :returns: void + :rtype: void + """ + import sys + + percent = float(iteration_counter) / total_items + hashes = '#' * int(round(percent * progress_bar_length)) + spaces = ' ' * (progress_bar_length - len(hashes)) + sys.stdout.write("\r{0}: [{1}] {2}%".format(description, hashes + spaces, int(round(percent * 100)))) + sys.stdout.flush() + if total_items == iteration_counter: + print("\r")