Support for regression #440
Comments
Hello @rnditdev, we are not currently planning to add regression in the near future, so if you would be interested in adding this we would greatly appreciate it! You could add a new class in the If you would like to give it a go, please let us know if you have any questions - we're happy to assist as much as we can!
I solved it with these two workaround classes. Not very elegant, but it solves my problem.
import flair
import torch
import torch.nn as nn
from typing import List, Union
from flair.data import Sentence, Label
from flair.training_utils import clear_embeddings


class TextRegressor(flair.models.TextClassifier):

    def __init__(self,
                 document_embeddings: flair.embeddings.DocumentEmbeddings,
                 label_dictionary: flair.data.Dictionary,
                 multi_label: bool):
        # The passed label_dictionary is ignored; a fresh Dictionary() is handed to the parent.
        super(TextRegressor, self).__init__(document_embeddings=document_embeddings,
                                            label_dictionary=flair.data.Dictionary(),
                                            multi_label=multi_label)
        # Regression uses mean squared error instead of a classification loss.
        self.loss_function = nn.MSELoss()

    def _labels_to_indices(self, sentences: List[Sentence]):
        # Labels are stored as strings; cast to float, since targets need not be integers.
        indices = [
            torch.FloatTensor([float(label.value) for label in sentence.labels])
            for sentence in sentences
        ]
        vec = torch.cat(indices, 0)
        if torch.cuda.is_available():
            vec = vec.cuda()
        return vec

    def forward_labels_and_loss(self, sentences: Union[Sentence, List[Sentence]]) -> (List[List[float]], torch.tensor):
        scores = self.forward(sentences)
        loss = self._calculate_loss(scores, sentences)
        return scores, loss

    def predict(self, sentences: Union[Sentence, List[Sentence]], mini_batch_size: int = 32) -> List[Sentence]:
        with torch.no_grad():
            if type(sentences) is Sentence:
                sentences = [sentences]
            filtered_sentences = self._filter_empty_sentences(sentences)
            batches = [filtered_sentences[x:x + mini_batch_size]
                       for x in range(0, len(filtered_sentences), mini_batch_size)]
            for batch in batches:
                scores = self.forward(batch)
                # Store each predicted value as the sentence's single label.
                for (sentence, score) in zip(batch, scores.tolist()):
                    sentence.labels = [Label(value=str(score[0]))]
                clear_embeddings(batch)
            return sentences
from sklearn.metrics import mean_squared_error, mean_absolute_error
from flair.training_utils import Metric, EvaluationMetric, clear_embeddings
from pathlib import Path
import logging

log = logging.getLogger('flair')


class RegressorTrainer(flair.trainers.ModelTrainer):

    @staticmethod
    def _evaluate_text_regressor(model: flair.nn.Model,
                                 sentences: List[Sentence],
                                 eval_mini_batch_size: int = 32,
                                 embeddings_in_memory: bool = False) -> (dict, float):
        with torch.no_grad():
            eval_loss = 0
            batches = [sentences[x:x + eval_mini_batch_size] for x in
                       range(0, len(sentences), eval_mini_batch_size)]
            metric = {}
            predictions, true_values = [], []
            for batch in batches:
                scores, loss = model.forward_labels_and_loss(batch)
                # Collect predictions and gold values over all batches, not only the last one.
                predictions.extend(score[0] for score in scores.tolist())
                for sentence in batch:
                    for label in sentence.labels:
                        true_values.append(float(label.value))
                clear_embeddings(batch, also_clear_word_embeddings=not embeddings_in_memory)
                eval_loss += loss
            # Plain Python lists are enough for scikit-learn, so no .cuda() call is needed.
            metric['mae'] = mean_absolute_error(true_values, predictions)
            metric['mse'] = mean_squared_error(true_values, predictions)
            eval_loss /= len(sentences)
            return metric, eval_loss

    def _calculate_evaluation_results_for(self,
                                          dataset_name: str,
                                          dataset: List[Sentence],
                                          evaluation_metric: EvaluationMetric,
                                          embeddings_in_memory: bool,
                                          eval_mini_batch_size: int,
                                          out_path: Path = None):
        metric, loss = RegressorTrainer._evaluate_text_regressor(self.model, dataset,
                                                                 eval_mini_batch_size=eval_mini_batch_size,
                                                                 embeddings_in_memory=embeddings_in_memory)
        mae = metric['mae']
        mse = metric['mse']
        log.info(f'{dataset_name:<5}: loss {loss:.8f} - mse {mse:.4f} - mae {mae:.4f}')
        return Metric('Evaluation'), loss

How to use it:
regressor = TextRegressor(document_embeddings, label_dictionary=label_dict, multi_label=False)
trainer = RegressorTrainer(regressor, corpus)
It still needs a lot of improvements to be right. Hope it helps.
Hello @heukirne, thanks for posting this - this is really helpful and could maybe be integrated into Flair. Just a quick question: How do you load the corpus in this example? Do you use the Label field of a sentence to store the value you wish to predict? Could you provide an example in which the data is loaded and the regressor trained? Edit: And sorry for the late reply!
Hi @alanakbik, yes, I'll create a PR for these classes next week with all the changes needed for regression to work properly with Flair. For now I use the same Label structure to set the value. It is stored as a string, but I cast it to float every time I use it. This is how I load my corpus:
import pandas as pd
from flair.data import Sentence, Label, TaggedCorpus

dfCharlson = pd.read_csv('charlson.csv')
sentences = []
for idx in dfCharlson.index:
    data = dfCharlson.iloc[idx]
    # The regression target is stored as the string value of a Label.
    sentence = Sentence(data.text, labels=[Label(value=str(data.target))], use_tokenizer=True)
    sentences.append(sentence)
corpus = TaggedCorpus([sentences[t] for t in train], [sentences[d] for d in dev], [sentences[e] for e in test])
train, dev, and test are arrays of indices created using sklearn.model_selection.StratifiedKFold.
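The comment above does not show how the train, dev, and test index arrays were built, only that StratifiedKFold was used. A minimal sketch of one way to do it follows. It assumes the targets are discrete (StratifiedKFold rejects continuous targets, so real-valued scores would need binning first), and `stratified_split` is a hypothetical helper, not part of Flair or scikit-learn.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def stratified_split(targets, n_splits=5, seed=0):
    """Return (train, dev, test) index arrays. Hypothetical helper, not part of Flair."""
    targets = np.asarray(targets)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # split() only uses X for its length, so a placeholder array is enough.
    folds = [held_out for _, held_out in skf.split(np.zeros(len(targets)), targets)]
    # One held-out fold becomes test, one becomes dev, the rest are merged into train.
    test, dev = folds[0], folds[1]
    train = np.concatenate(folds[2:])
    return train, dev, test


# Toy integer targets standing in for the real scores (made-up data).
targets = [0, 1, 2] * 10
train, dev, test = stratified_split(targets)
```

The three arrays are disjoint and together cover every row, so they can be used to index the `sentences` list as in the corpus construction above.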
Ah great - looking forward to the PR :) I'd be very interested to see how well regression works!
TODO: still need a self-contained MSE and MAE metric
add mean squared error as default for regression
still need unit test for MetricRegression
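The commit notes above mention a self-contained MSE and MAE metric. As a rough illustration of what such a metric could look like without scikit-learn, here is a minimal sketch. The class name `SimpleRegressionMetric` and its methods are assumptions made for this example, not the class that was merged into Flair.

```python
from typing import List


class SimpleRegressionMetric:
    """Hypothetical helper, not Flair's API: accumulates gold values and predictions."""

    def __init__(self, name: str):
        self.name = name
        self.true: List[float] = []
        self.pred: List[float] = []

    def add(self, true_value: float, prediction: float) -> None:
        # Labels may arrive as strings, so cast both values to float.
        self.true.append(float(true_value))
        self.pred.append(float(prediction))

    def _check_not_empty(self) -> None:
        if not self.true:
            raise ValueError("no values have been added to the metric")

    def mean_squared_error(self) -> float:
        self._check_not_empty()
        return sum((t - p) ** 2 for t, p in zip(self.true, self.pred)) / len(self.true)

    def mean_absolute_error(self) -> float:
        self._check_not_empty()
        return sum(abs(t - p) for t, p in zip(self.true, self.pred)) / len(self.true)

    def __str__(self) -> str:
        return (f"{self.name}: mse {self.mean_squared_error():.4f} "
                f"- mae {self.mean_absolute_error():.4f}")


# Made-up (gold, predicted) pairs for illustration only.
metric = SimpleRegressionMetric("dev")
for true_value, prediction in [(1.0, 1.5), (2.0, 2.0), (4.0, 3.0)]:
    metric.add(true_value, prediction)
```

Accumulating values first and computing the error at the end gives a metric over the whole dataset rather than over a single mini-batch.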
PR with regression implementation pushed to master!
…task GH-440: added WASSA 2017 emotion intensity task
Could you please add a sample or tutorial for the regression problem? Thanks!
There is an example in tests/test_text_regressor.py
Hi,
Hello @aditya-malte, here are some examples of the regression format
Hi,
Hi, @aditya-malte.
Oh, ok. That’s great then! Some of the function names had me worried that we were binning and passing it as a classification task. |
Hello,
I'd like to tackle regression problems with Flair. I think it should be a matter of slightly altering the last layer of the model and the loss function.
I could try to make the modifications myself, but I'm just asking whether this is a planned feature, or whether the community can point me to the right layer types to replace.
Best Regards.
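The idea in the question, swapping the last layer and the loss function, can be sketched in plain PyTorch. This is only an illustration under assumptions: `RegressionHead` is a hypothetical module, and the random tensors stand in for document embeddings that a Flair embedding class would normally produce.

```python
import torch
import torch.nn as nn


class RegressionHead(nn.Module):
    """Hypothetical final layer: maps a document embedding to one real value."""

    def __init__(self, embedding_length: int):
        super().__init__()
        # A classifier would use one output per class; regression needs a single output.
        self.decoder = nn.Linear(embedding_length, 1)

    def forward(self, document_embeddings: torch.Tensor) -> torch.Tensor:
        # Drop the trailing dimension so predictions have shape (batch,), like the targets.
        return self.decoder(document_embeddings).squeeze(-1)


torch.manual_seed(0)
head = RegressionHead(embedding_length=8)
loss_function = nn.MSELoss()  # replaces the cross-entropy loss of a classifier
optimizer = torch.optim.SGD(head.parameters(), lr=0.01)

# Made-up embeddings and targets, for illustration only.
embeddings = torch.randn(4, 8)
targets = torch.tensor([0.1, 0.5, 0.9, 0.3])

first_loss = loss_function(head(embeddings), targets).item()
for _ in range(100):
    optimizer.zero_grad()
    loss = loss_function(head(embeddings), targets)
    loss.backward()
    optimizer.step()
final_loss = loss_function(head(embeddings), targets).item()
```

Keeping predictions and targets in the same shape matters here, because `nn.MSELoss` would otherwise broadcast a `(batch, 1)` tensor against a `(batch,)` tensor.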