
Support for regression #440

Closed
rnditdev opened this issue Feb 1, 2019 · 13 comments

Comments

@rnditdev

rnditdev commented Feb 1, 2019

Hello,
I'd like to tackle regression problems with Flair. I think it should be a matter of slightly altering the last layer of the model and the loss function.
I could try to make the modifications myself, but I'm just asking whether it's a planned feature, or whether the community can point me to the right layer types to replace.

Best Regards.

@alanakbik
Collaborator

Hello @rnditdev, we are not currently planning to add regression in the near future, so if you would be interested in adding this, we would greatly appreciate it!

You could add a new class in the models/ folder that is probably very similar to the TextClassifier, just slimmed down a bit, since there is no need to distinguish between single-class and multi-class problems. Like the TextClassifier, your class would probably inherit from flair.nn.Model. That would make it compatible with the ModelTrainer.

If you would like to give it a go, please let us know if you have any questions - we're happy to assist as much as we can!

@heukirne
Contributor

heukirne commented Feb 16, 2019

I solved it with these two workaround classes. It's not very elegant, but it solves my problem.

  • TextRegressor:
    • dummy label_dictionary to decode to single dimension
    • change to MSELoss
    • get labels as float
import flair
import torch
import torch.nn as nn
from typing import List, Union
from flair.data import Sentence, Label
from flair.training_utils import clear_embeddings

class TextRegressor(flair.models.TextClassifier):

    def __init__(self,
                 document_embeddings: flair.embeddings.DocumentEmbeddings,
                 label_dictionary: flair.data.Dictionary,
                 multi_label: bool):

        # label_dictionary is ignored: a dummy Dictionary() is passed so that
        # the inherited decoder produces a single output value
        super(TextRegressor, self).__init__(document_embeddings=document_embeddings,
                                            label_dictionary=flair.data.Dictionary(),
                                            multi_label=multi_label)

        # regression loss instead of cross-entropy
        self.loss_function = nn.MSELoss()

    def _labels_to_indices(self, sentences: List[Sentence]):
        # targets are stored as label strings; parse them as floats, not ints
        indices = [
            torch.FloatTensor([float(label.value) for label in sentence.labels])
            for sentence in sentences
        ]

        # shape (batch_size, 1) to match the (batch_size, 1) decoder output
        vec = torch.cat(indices, 0).view(-1, 1)
        if torch.cuda.is_available():
            vec = vec.cuda()

        return vec

    def forward_labels_and_loss(self, sentences: Union[Sentence, List[Sentence]]) -> (List[List[float]], torch.tensor):
        scores = self.forward(sentences)
        loss = self._calculate_loss(scores, sentences)
        return scores, loss

    def predict(self, sentences: Union[Sentence, List[Sentence]], mini_batch_size: int = 32) -> List[Sentence]:

        with torch.no_grad():
            if isinstance(sentences, Sentence):
                sentences = [sentences]

            filtered_sentences = self._filter_empty_sentences(sentences)

            batches = [filtered_sentences[x:x + mini_batch_size] for x in range(0, len(filtered_sentences), mini_batch_size)]

            for batch in batches:
                scores = self.forward(batch)

                # store each predicted value as a string label on its sentence
                for (sentence, score) in zip(batch, scores.tolist()):
                    sentence.labels = [Label(value=str(score[0]))]

                clear_embeddings(batch)

            return sentences
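The `_labels_to_indices` override is where the string labels become regression targets. Below is a plain-Python sketch of just that step, using a hypothetical helper `labels_to_targets` (not part of Flair): each label string is parsed with `float`, since `int("0.25")` raises a `ValueError`, and wrapped in its own one-element row so the targets line up with a decoder that has a single output.

```python
from typing import List


def labels_to_targets(label_values: List[str]) -> List[List[float]]:
    """Parse string labels into a column of float regression targets.

    Hypothetical helper illustrating the idea behind TextRegressor._labels_to_indices:
    one row per sentence, one column because the decoder has a single output.
    """
    return [[float(value)] for value in label_values]


print(labels_to_targets(["3", "0.25", "-1.5"]))  # [[3.0], [0.25], [-1.5]]

# int() cannot parse fractional targets, which is why the labels are read as floats
try:
    int("0.25")
except ValueError as err:
    print("int() fails:", err)
```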
  • RegressorTrainer:
    • evaluate with MSE and MAE
from sklearn.metrics import mean_squared_error, mean_absolute_error
from flair.training_utils import Metric, EvaluationMetric, clear_embeddings
from flair.data import Sentence
from typing import List
from pathlib import Path
import logging
import flair
import torch

log = logging.getLogger('flair')

class RegressorTrainer(flair.trainers.ModelTrainer):

    @staticmethod
    def _evaluate_text_regressor(model: flair.nn.Model,
                                 sentences: List[Sentence],
                                 eval_mini_batch_size: int = 32,
                                 embeddings_in_memory: bool = False) -> (dict, float):

        with torch.no_grad():
            eval_loss = 0

            batches = [sentences[x:x + eval_mini_batch_size] for x in
                       range(0, len(sentences), eval_mini_batch_size)]

            # collect predictions and gold values over all batches
            predicted_values = []
            true_values = []

            for batch in batches:

                scores, loss = model.forward_labels_and_loss(batch)

                predicted_values.extend(score[0] for score in scores.tolist())
                for sentence in batch:
                    for label in sentence.labels:
                        true_values.append(float(label.value))

                clear_embeddings(batch, also_clear_word_embeddings=not embeddings_in_memory)

                # MSELoss averages within the batch; weight it by batch size
                eval_loss += loss.item() * len(batch)

            eval_loss /= len(sentences)

            # compute the metrics once over the whole dataset, not per batch
            metric = {
                'mae': mean_absolute_error(true_values, predicted_values),
                'mse': mean_squared_error(true_values, predicted_values),
            }

            return metric, eval_loss

  
    def _calculate_evaluation_results_for(self,
                                          dataset_name: str,
                                          dataset: List[Sentence],
                                          evaluation_metric: EvaluationMetric,
                                          embeddings_in_memory: bool,
                                          eval_mini_batch_size: int,
                                          out_path: Path = None):

        metric, loss = RegressorTrainer._evaluate_text_regressor(self.model, dataset, eval_mini_batch_size=eval_mini_batch_size,
                                             embeddings_in_memory=embeddings_in_memory)

        mae = metric['mae']
        mse = metric['mse']

        log.info(f'{dataset_name:<5}: loss {loss:.8f} - mse {mse:.4f} - mae {mae:.4f}')

        return Metric('Evaluation'), loss
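For reference, here is a minimal plain-Python sketch of the two evaluation metrics, with illustrative names `mse` and `mae` (these are not Flair or sklearn functions). Both average over every sentence in the evaluation set, so predictions should be collected across all batches before the metric is computed, rather than recomputed batch by batch.

```python
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error: average of squared differences."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error: average of absolute differences."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


gold = [3.0, 0.0, 2.0]
pred = [2.5, 0.5, 4.0]
print(mse(gold, pred))  # (0.25 + 0.25 + 4.0) / 3 = 1.5
print(mae(gold, pred))  # (0.5 + 0.5 + 2.0) / 3 = 1.0
```

Averaging per-batch values only reproduces these numbers when every batch has the same size; a smaller final batch would otherwise be over-weighted.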

How to use it:

regressor = TextRegressor(document_embeddings, label_dictionary=label_dict, multi_label=False)
trainer = RegressorTrainer(regressor, corpus)

There is still a lot to improve to make it fully correct.

Hope it helps.

@alanakbik
Collaborator

alanakbik commented Feb 19, 2019

Hello @heukirne thanks for posting this - this is really helpful and could maybe be integrated into Flair.

Just a quick question: how do you load the corpus in this example? Do you use the Label field of a sentence to store the value you wish to predict? Could you provide an example in which the data is loaded and the regressor is trained?

Edit: And sorry for the late reply!

@heukirne
Contributor

Hi @alanakbik, yes, I'll create a PR for these classes next week with all the changes needed for regression to work properly with Flair.

At the moment I use the same Label structure to store the value. It is stored as a string, but I cast it to float every time I use it. This is how I load my corpus:

import pandas as pd
from flair.data import Sentence, Label, TaggedCorpus

dfCharlson = pd.read_csv('charlson.csv')
sentences = []
for idx in range(len(dfCharlson)):
    data = dfCharlson.iloc[idx]
    # the regression target is stored as a string label and cast back to float when used
    sentence = Sentence(data.text, labels=[Label(value=str(data.target))], use_tokenizer=True)
    sentences.append(sentence)

# train, dev and test hold positions into `sentences`
corpus = TaggedCorpus([sentences[t] for t in train], [sentences[d] for d in dev], [sentences[e] for e in test])

train, dev and test are index arrays created with sklearn.model_selection.StratifiedKFold.
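As a dependency-free sketch of the same idea (not sklearn's implementation), one could group sentence positions by their integer target, deal each group round-robin over k folds, then use one fold as test, the next as dev, and the rest as train. The functions `stratified_folds` and `train_dev_test` and the choice of dev fold are hypothetical; stratifying like this only makes sense for discrete targets such as integer scores.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_folds(targets: Sequence[int], k: int) -> List[List[int]]:
    """Assign every position to one of k folds so each target value is spread evenly."""
    by_value: Dict[int, List[int]] = defaultdict(list)
    for position, value in enumerate(targets):
        by_value[value].append(position)

    folds: List[List[int]] = [[] for _ in range(k)]
    slot = 0
    # deal the positions of each target value round-robin over the folds
    for value in sorted(by_value):
        for position in by_value[value]:
            folds[slot % k].append(position)
            slot += 1
    return [sorted(fold) for fold in folds]


def train_dev_test(targets: Sequence[int], k: int,
                   test_fold: int) -> Tuple[List[int], List[int], List[int]]:
    """Use fold `test_fold` as test, the next fold as dev, and all others as train."""
    if k < 3:
        raise ValueError("need at least 3 folds for separate train, dev and test sets")
    folds = stratified_folds(targets, k)
    dev_fold = (test_fold + 1) % k
    train = sorted(i for f, fold in enumerate(folds)
                   if f not in (test_fold, dev_fold) for i in fold)
    return train, folds[dev_fold], folds[test_fold]


train, dev, test = train_dev_test([0, 0, 0, 1, 1, 1, 2, 2, 2], k=3, test_fold=0)
print(train, dev, test)  # [2, 5, 8] [1, 4, 7] [0, 3, 6]
```

Each split ends up with one sentence per target value here, which is the property stratification is meant to preserve.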

@alanakbik
Collaborator

Ah great - looking forward to the PR :) I'd be very interested to see how well regression works!

@alanakbik alanakbik mentioned this issue Feb 24, 2019
heukirne added a commit to nlp-pucrs/flair that referenced this issue Feb 24, 2019
heukirne added a commit to nlp-pucrs/flair that referenced this issue Feb 25, 2019
heukirne added a commit to nlp-pucrs/flair that referenced this issue Feb 25, 2019
heukirne added a commit to nlp-pucrs/flair that referenced this issue Feb 25, 2019
TODO: still need a self-contained MSE and MAE metric
heukirne added a commit to nlp-pucrs/flair that referenced this issue Feb 26, 2019
heukirne added a commit to nlp-pucrs/flair that referenced this issue Mar 6, 2019
add mean squared error as default for regression
@alanakbik
Collaborator

PR with regression implementation pushed to master!

alanakbik pushed a commit that referenced this issue May 2, 2019
alanakbik pushed a commit that referenced this issue May 7, 2019
…task

GH-440: added WASSA 2017 emotion intensity task
@leo-gan

leo-gan commented Dec 16, 2019

Can you, please, add a sample/tutorial for the regression problem? Thanks!

@heukirne
Contributor

heukirne commented Dec 17, 2019

There is an example in tests/test_text_regressor.py ;)

@aditya-malte

aditya-malte commented Jan 11, 2021

Hi,
Firstly, I'd like to thank you for the really great work on the regression task.
However, it's a bit unclear what the original data format should be (for the FastText format), and also how the model is trained as a regression task. I looked through the source code and found "label_to_indices"; does this mean that we are passing int/float values as classes to the model? (Of course, the MSE loss makes that seem unlikely.) Some clarity on these two points would be very helpful.
Thanks

@heukirne
Contributor

Hello @aditya-malte, here are some examples of the regression format:
https://github.com/flairNLP/flair/tree/master/tests/resources/tasks/regression
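If those files follow the FastText-style convention of a `__label__` prefix followed by the text (an assumption here; check the linked files for the exact layout), a minimal parser could read the prefix as a float target and keep the rest as the sentence text. The function names and the made-up sample lines are hypothetical and are not Flair's own loader.

```python
from typing import List, Tuple

LABEL_PREFIX = "__label__"  # assumed FastText-style label prefix


def parse_regression_line(line: str) -> Tuple[float, str]:
    """Split '__label__<value> <text>' into a float target and the sentence text."""
    head, _, text = line.strip().partition(" ")
    if not head.startswith(LABEL_PREFIX):
        raise ValueError(f"missing {LABEL_PREFIX!r} prefix: {line!r}")
    return float(head[len(LABEL_PREFIX):]), text


def parse_regression_lines(lines: List[str]) -> List[Tuple[float, str]]:
    """Parse every non-blank line, skipping blank ones."""
    return [parse_regression_line(line) for line in lines if line.strip()]


sample = ["__label__0.75 I really enjoyed this movie", "__label__0.1 Not worth watching", ""]
print(parse_regression_lines(sample))
# [(0.75, 'I really enjoyed this movie'), (0.1, 'Not worth watching')]
```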

@aditya-malte

Hi,
Thanks a lot for the quick response! I checked it out. By the way, how is the value eventually passed to the model: through something like binning, which breaks continuous values into classes (this is what the source code made me think), or as proper regression?

@heukirne
Contributor

Hi, @aditya-malte.
This neural network has a numeric output, so it works as a true regression model: it learns to predict the numeric values directly. ;)
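To make the distinction concrete: a regression head maps a document embedding to one real number and is trained with mean squared error, so no value is ever turned into a class. The plain-Python sketch below assumes fixed, made-up 2-d feature vectors standing in for document embeddings and fits only a final linear layer by gradient descent. In Flair this layer sits on top of the document embeddings and PyTorch computes the gradients; the function names here are illustrative only.

```python
from typing import List, Tuple


def predict(weights: List[float], bias: float, x: List[float]) -> float:
    """Single output unit: a real-valued score, not a class index."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def fit_linear_head(data: List[Tuple[List[float], float]], lr: float = 0.1,
                    epochs: int = 2000) -> Tuple[List[float], float]:
    """Fit weights and bias by full-batch gradient descent on mean squared error."""
    dim = len(data[0][0])
    weights, bias = [0.0] * dim, 0.0
    n = len(data)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in data:
            error = predict(weights, bias, x) - y  # d(MSE)/d(pred) = 2 * error / n
            for j in range(dim):
                grad_w[j] += 2 * error * x[j] / n
            grad_b += 2 * error / n
        weights = [w - lr * g for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b
    return weights, bias


# made-up "embeddings" whose targets follow y = 2*x0 - x1 + 0.5
data = [([1.0, 0.0], 2.5), ([0.0, 1.0], -0.5), ([1.0, 1.0], 1.5), ([0.5, 0.5], 1.0)]
weights, bias = fit_linear_head(data)
print(round(predict(weights, bias, [2.0, 1.0]), 2))  # 3.5
```

The model predicts 3.5 for an input it never saw, a value that is not one of the training targets, which a binned classifier could not do.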

@aditya-malte

Oh, ok. That’s great then! Some of the function names had me worried that we were binning and passing it as a classification task.
