Skip to content

Commit

Permalink
Merge pull request #564 from nlp-pucrs/regression
Browse files Browse the repository at this point in the history
Flair Regression
  • Loading branch information
Alan Akbik authored Apr 16, 2019
2 parents f1c3d3c + 330dd56 commit 11850ee
Show file tree
Hide file tree
Showing 9 changed files with 1,954 additions and 4 deletions.
4 changes: 4 additions & 0 deletions flair/data_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class NLPTask(Enum):
TREC_6 = "trec-6"
TREC_50 = "trec-50"

# text regression format
REGRESSION = 'regression'


class NLPTaskDataFetcher:
@staticmethod
Expand Down Expand Up @@ -210,6 +213,7 @@ def load_corpus(
NLPTask.AG_NEWS.value,
NLPTask.TREC_6.value,
NLPTask.TREC_50.value,
NLPTask.REGRESSION.value,
]:
use_tokenizer: bool = False if task in [
NLPTask.TREC_6.value,
Expand Down
59 changes: 59 additions & 0 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import flair
import torch
import torch.nn as nn
from typing import List, Union
from flair.training_utils import clear_embeddings
from flair.data import Sentence, Label
import logging

log = logging.getLogger('flair')

class TextRegressor(flair.models.TextClassifier):

def __init__(self,
document_embeddings: flair.embeddings.DocumentEmbeddings,
label_dictionary: flair.data.Dictionary,
multi_label: bool):

super(TextRegressor, self).__init__(document_embeddings=document_embeddings, label_dictionary=flair.data.Dictionary(), multi_label=multi_label)

log.info('Using REGRESSION - experimental')

self.loss_function = nn.MSELoss()

def _labels_to_indices(self, sentences: List[Sentence]):
indices = [
torch.FloatTensor([float(label.value) for label in sentence.labels])
for sentence in sentences
]

vec = torch.cat(indices, 0)
if torch.cuda.is_available():
vec = vec.cuda()

return vec

def forward_labels_and_loss(self, sentences: Union[Sentence, List[Sentence]]) -> (List[List[float]], torch.tensor):
scores = self.forward(sentences)
loss = self._calculate_loss(scores, sentences)
return scores, loss

def predict(self, sentences: Union[Sentence, List[Sentence]], mini_batch_size: int = 32) -> List[Sentence]:

with torch.no_grad():
if type(sentences) is Sentence:
sentences = [sentences]

filtered_sentences = self._filter_empty_sentences(sentences)

batches = [filtered_sentences[x:x + mini_batch_size] for x in range(0, len(filtered_sentences), mini_batch_size)]

for batch in batches:
scores = self.forward(batch)

for (sentence, score) in zip(batch, scores.tolist()):
sentence.labels = [Label(value=str(score[0]))]

clear_embeddings(batch)

return sentences
147 changes: 147 additions & 0 deletions flair/trainers/trainer_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import flair
import torch
import torch.nn as nn

from typing import List, Union
from flair.training_utils import MetricRegression, EvaluationMetric, clear_embeddings, log_line
from flair.models.text_regression_model import TextRegressor
from flair.data import Sentence, Label
from pathlib import Path
import logging

log = logging.getLogger('flair')

class RegressorTrainer(flair.trainers.ModelTrainer):

def train(self,
base_path: Union[Path, str],
evaluation_metric: EvaluationMetric = EvaluationMetric.MEAN_SQUARED_ERROR,
learning_rate: float = 0.1,
mini_batch_size: int = 32,
eval_mini_batch_size: int = None,
max_epochs: int = 100,
anneal_factor: float = 0.5,
patience: int = 3,
anneal_against_train_loss: bool = True,
train_with_dev: bool = False,
monitor_train: bool = False,
embeddings_in_memory: bool = True,
checkpoint: bool = False,
save_final_model: bool = True,
anneal_with_restarts: bool = False,
test_mode: bool = False,
param_selection_mode: bool = False,
**kwargs
) -> dict:

return super(RegressorTrainer, self).train(
base_path=base_path,
evaluation_metric=evaluation_metric,
learning_rate=learning_rate,
mini_batch_size=mini_batch_size,
eval_mini_batch_size=eval_mini_batch_size,
max_epochs=max_epochs,
anneal_factor=anneal_factor,
patience=patience,
anneal_against_train_loss=anneal_against_train_loss,
train_with_dev=train_with_dev,
monitor_train=monitor_train,
embeddings_in_memory=embeddings_in_memory,
checkpoint=checkpoint,
save_final_model=save_final_model,
anneal_with_restarts=anneal_with_restarts,
test_mode=test_mode,
param_selection_mode=param_selection_mode)

@staticmethod
def _evaluate_text_regressor(model: flair.nn.Model,
sentences: List[Sentence],
eval_mini_batch_size: int = 32,
embeddings_in_memory: bool = False,
out_path: Path = None) -> (dict, float):

with torch.no_grad():
eval_loss = 0

batches = [sentences[x:x + eval_mini_batch_size] for x in
range(0, len(sentences), eval_mini_batch_size)]

metric = MetricRegression('Evaluation')

lines: List[str] = []
for batch in batches:

scores, loss = model.forward_labels_and_loss(batch)

true_values = []
for sentence in batch:
for label in sentence.labels:
true_values.append(float(label.value))

results = []
for score in scores:
if type(score[0]) is Label:
results.append(float(score[0].score))
else:
results.append(float(score[0]))

clear_embeddings(batch, also_clear_word_embeddings=not embeddings_in_memory)

eval_loss += loss

metric.true.extend(true_values)
metric.pred.extend(results)

eval_loss /= len(sentences)

##TODO: not saving lines yet
if out_path is not None:
with open(out_path, "w", encoding='utf-8') as outfile:
outfile.write(''.join(lines))

return metric, eval_loss


def _calculate_evaluation_results_for(self,
dataset_name: str,
dataset: List[Sentence],
evaluation_metric: EvaluationMetric,
embeddings_in_memory: bool,
eval_mini_batch_size: int,
out_path: Path = None):

metric, loss = RegressorTrainer._evaluate_text_regressor(self.model, dataset, eval_mini_batch_size=eval_mini_batch_size,
embeddings_in_memory=embeddings_in_memory, out_path=out_path)

mse = metric.mean_squared_error()
mae = metric.mean_absolute_error()

log.info(f'{dataset_name:<5}: loss {loss:.8f} - mse {mse:.4f} - mae {mae:.4f}')

return metric, loss

def final_test(self,
base_path: Path,
embeddings_in_memory: bool,
evaluation_metric: EvaluationMetric,
eval_mini_batch_size: int):

log_line(log)
log.info('Testing using best model ...')

self.model.eval()

if (base_path / 'best-model.pt').exists():
self.model = TextRegressor.load_from_file(base_path / 'best-model.pt')

test_metric, test_loss = self._evaluate_text_regressor(self.model, self.corpus.test, eval_mini_batch_size=eval_mini_batch_size,
embeddings_in_memory=embeddings_in_memory)

log.info(f'AVG: mse: {test_metric.mean_squared_error():.4f} - '
f'mae: {test_metric.mean_absolute_error():.4f} - '
f'pearson: {test_metric.pearsonr():.4f} - '
f'spearman: {test_metric.spearmanr():.4f}')

log_line(log)

return test_metric.mean_squared_error()
64 changes: 60 additions & 4 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import List
from flair.data import Dictionary, Sentence
from functools import reduce
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr, spearmanr


class Metric(object):
Expand Down Expand Up @@ -171,11 +173,65 @@ def __str__(self):
return "\n".join(all_lines)


class MetricRegression(object):

def __init__(self, name):
self.name = name

self.true = []
self.pred = []

def mean_squared_error(self):
return mean_squared_error(self.true, self.pred)

def mean_absolute_error(self):
return mean_absolute_error(self.true, self.pred)

def pearsonr(self):
return pearsonr(self.true, self.pred)[0]

def spearmanr(self):
return spearmanr(self.true, self.pred)[0]

## dummy return to fulfill trainer.train() needs
def micro_avg_f_score(self):
return self.mean_squared_error()

def to_tsv(self):
return '{}\t{}\t{}\t{}'.format(
self.mean_squared_error(),
self.mean_absolute_error(),
self.pearsonr(),
self.spearmanr(),
)

@staticmethod
def tsv_header(prefix=None):
if prefix:
return '{0}_MEAN_SQUARED_ERROR\t{0}_MEAN_ABSOLUTE_ERROR\t{0}_PEARSON\t{0}_SPEARMAN'.format(
prefix)

return 'MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN'

@staticmethod
def to_empty_tsv():
return '\t_\t_\t_\t_'

def __str__(self):
line = 'mean squared error: {0:.4f} - mean absolute error: {1:.4f} - pearson: {2:.4f} - spearman: {3:.4f}'.format(
self.mean_squared_error(),
self.mean_absolute_error(),
self.pearsonr(),
self.spearmanr())
return line


class EvaluationMetric(Enum):
MICRO_ACCURACY = "micro-average accuracy"
MICRO_F1_SCORE = "micro-average f1-score"
MACRO_ACCURACY = "macro-average accuracy"
MACRO_F1_SCORE = "macro-average f1-score"
MICRO_ACCURACY = 'micro-average accuracy'
MICRO_F1_SCORE = 'micro-average f1-score'
MACRO_ACCURACY = 'macro-average accuracy'
MACRO_F1_SCORE = 'macro-average f1-score'
MEAN_SQUARED_ERROR = 'mean squared error'


class WeightExtractor(object):
Expand Down
14 changes: 14 additions & 0 deletions tests/resources/tasks/regression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## REGRESSION

Data is taken from [here](http://saifmohammad.com/WebPages/EmotionIntensity-SharedTask.html).

The dataset contains a collection of tweets with joy intensity value.
We took the joy dataset and converted it to the expected format of our data fetcher:
```
__label__<joy_intensity> <text>
```

#### Publication About the Dataset

* Emotion Intensities in Tweets. Saif M. Mohammad and Felipe Bravo-Marquez. In Proceedings of the sixth joint conference on lexical and computational semantics (*Sem), August 2017, Vancouver, Canada.
* WASSA-2017 Shared Task on Emotion Intensity. Saif M. Mohammad and Felipe Bravo-Marquez. In Proceedings of the EMNLP 2017 Workshop on Computational Approaches to Subjectivity, Sentiment, and Social Media (WASSA), September 2017, Copenhagen, Denmark.
Loading

0 comments on commit 11850ee

Please sign in to comment.