diff --git a/flair/trainers/trainer_regression.py b/flair/trainers/trainer_regression.py
index 5e39022adf..648b2bc412 100644
--- a/flair/trainers/trainer_regression.py
+++ b/flair/trainers/trainer_regression.py
@@ -15,6 +15,47 @@ class RegressorTrainer(flair.trainers.ModelTrainer):
 
 
 
+    def train(self,
+              base_path: Union[Path, str],
+              evaluation_metric: EvaluationMetric = EvaluationMetric.MEAN_SQUARED_ERROR,
+              learning_rate: float = 0.1,
+              mini_batch_size: int = 32,
+              eval_mini_batch_size: int = None,
+              max_epochs: int = 100,
+              anneal_factor: float = 0.5,
+              patience: int = 3,
+              anneal_against_train_loss: bool = True,
+              train_with_dev: bool = False,
+              monitor_train: bool = False,
+              embeddings_in_memory: bool = True,
+              checkpoint: bool = False,
+              save_final_model: bool = True,
+              anneal_with_restarts: bool = False,
+              test_mode: bool = False,
+              param_selection_mode: bool = False,
+              **kwargs
+              ) -> dict:
+
+        return super(RegressorTrainer, self).train(
+            base_path=base_path,
+            evaluation_metric=evaluation_metric,
+            learning_rate=learning_rate,
+            mini_batch_size=mini_batch_size,
+            eval_mini_batch_size=eval_mini_batch_size,
+            max_epochs=max_epochs,
+            anneal_factor=anneal_factor,
+            patience=patience,
+            anneal_against_train_loss=anneal_against_train_loss,
+            train_with_dev=train_with_dev,
+            monitor_train=monitor_train,
+            embeddings_in_memory=embeddings_in_memory,
+            checkpoint=checkpoint,
+            save_final_model=save_final_model,
+            anneal_with_restarts=anneal_with_restarts,
+            test_mode=test_mode,
+            param_selection_mode=param_selection_mode,
+            **kwargs)
+
     @staticmethod
     def _evaluate_text_regressor(model: flair.nn.Model,
                                  sentences: List[Sentence],
diff --git a/flair/training_utils.py b/flair/training_utils.py
index 0d72d65177..929d20d447 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -140,6 +140,7 @@ class EvaluationMetric(Enum):
     MICRO_F1_SCORE = 'micro-average f1-score'
     MACRO_ACCURACY = 'macro-average accuracy'
     MACRO_F1_SCORE = 'macro-average f1-score'
+    MEAN_SQUARED_ERROR = 'mean squared error'
 
 
 class WeightExtractor(object):