diff --git a/chatterbot/training/trainers.py b/chatterbot/training/trainers.py index 372206713..a5c9fb052 100644 --- a/chatterbot/training/trainers.py +++ b/chatterbot/training/trainers.py @@ -5,7 +5,6 @@ class Trainer(object): def __init__(self, storage, **kwargs): - self.kwargs = kwargs self.storage = storage self.corpus = Corpus() @@ -15,23 +14,27 @@ def train(self): class ListTrainer(Trainer): + def get_or_create(self, statement_text): + """ + Return a statement if it exists. + Create and return the statement if it does not exist. + """ + statement = self.storage.find(statement_text) + + if not statement: + statement = Statement(statement_text) + + return statement + def train(self, conversation): statement_history = [] for text in conversation: - statement = self.storage.find(text) + statement = self.get_or_create(text) - # Create the statement if a match was not found - if not statement: - statement = Statement(text) - - previous_statement = None if statement_history: - previous_statement = statement_history[-1] - - if previous_statement: statement.add_response( - Response(previous_statement.text) + Response(statement_history[-1].text) ) statement_history.append(statement)