diff --git a/farm/train.py b/farm/train.py index 962150f35..ec85b7233 100644 --- a/farm/train.py +++ b/farm/train.py @@ -240,6 +240,9 @@ def train(self): If trainer evaluates the model with a test set the result of the evaluation is stored in ``test_result``. + + :return: Returns the model after training. When you do ``early_stopping`` + with a ``save_dir`` the best model is loaded and returned. """ # connect the prediction heads with the right output from processor @@ -352,8 +355,8 @@ def train(self): if self.early_stopping and self.early_stopping.save_dir: logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir)) lm_name = self.model.language_model.name - model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name) - model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True) + self.model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name) + self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True) # Eval on test set if self.evaluator_test: