
Commit

Merge pull request #20 from RUCAIBox/master
Merge
chenyushuo authored Aug 4, 2020
2 parents 232a5a7 + a43bc23 commit 85d304c
Showing 16 changed files with 266 additions and 319 deletions.
4 changes: 3 additions & 1 deletion evaluator/__init__.py
@@ -1 +1,3 @@
-from .evaluator import TopKEvaluator, LossEvaluator
+from .loss_evaluator import *
+from .topk_evaluator import *
+from .metrics import *
1 change: 0 additions & 1 deletion evaluator/cpp/__init__.py

This file was deleted.

34 changes: 0 additions & 34 deletions evaluator/cpp/cpp_evaluator.py

This file was deleted.

130 changes: 0 additions & 130 deletions evaluator/evaluator.py

This file was deleted.

67 changes: 67 additions & 0 deletions evaluator/loss_evaluator.py
@@ -0,0 +1,67 @@
from .metrics import metrics_dict
import numpy as np
import torch

# These metrics are typically used for loss-based (score prediction) evaluation
loss_metrics = {metric.lower(): metric for metric in ['AUC', 'RMSE', 'MAE', 'LOGLOSS']}


class LossEvaluator(object):

    def __init__(self, config, logger):
        self.metrics = config['metrics']
        self.label_field = config['LABEL_FIELD']

    def evaluate(self, interaction, pred_scores):
        """Gather the true and predicted scores for one batch of the loss metrics.

        Args:
            interaction (Interaction): batched data containing the true label scores
            pred_scores (tensor): the predicted scores

        Returns:
            tensor: a (batch_size, 2) tensor stacking the true and predicted scores
        """
        true_scores = interaction[self.label_field].cuda()
        return torch.stack((true_scores, pred_scores.detach()), dim=1)

    def collect(self, batch_matrix_list, *args):
        """Concatenate the per-batch results and compute the final metric values.

        Returns:
            dict: such as {'auc': 0.83}
        """
        concat = torch.cat(batch_matrix_list, dim=0).cpu().numpy()

        trues = concat[:, 0]
        preds = concat[:, 1]

        # get metrics
        metric_dict = {}
        result_list = self.eval_metrics(trues, preds)
        for metric, value in zip(self.metrics, result_list):
            key = '{}'.format(metric)
            metric_dict[key] = value
        return metric_dict

    def _check_args(self):

        # Check eval_metric
        if isinstance(self.metrics, (str, list)):
            if isinstance(self.metrics, str):
                self.metrics = [self.metrics]
        else:
            raise TypeError('eval_metric must be str or list')

        # Convert metric names to lowercase
        for m in self.metrics:
            if m.lower() not in loss_metrics:
                raise ValueError("There is no loss metric named {}!".format(m))
        self.metrics = [metric.lower() for metric in self.metrics]

    def metrics_info(self, trues, preds):
        result_list = []
        for metric in self.metrics:
            metric_func = metrics_dict[metric.lower()]
            result = metric_func(trues, preds)
            result_list.append(result)
        return result_list

    def eval_metrics(self, trues, preds):
        return self.metrics_info(trues, preds)
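The new evaluator splits evaluation into a per-batch `evaluate` step, which stacks true and predicted scores, and a `collect` step, which concatenates the batches and computes the configured metrics. A minimal, CPU-only sketch of that flow (plain tensors stand in for the project's batched `Interaction` and config; the names below are illustrative, not the project's API):

```python
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, mean_squared_error

# Hypothetical per-batch data: true labels and model scores.
batches = [
    (torch.tensor([1.0, 0.0, 1.0]), torch.tensor([0.9, 0.2, 0.6])),
    (torch.tensor([0.0, 1.0]), torch.tensor([0.4, 0.7])),
]

# Per-batch "evaluate": stack true and predicted scores column-wise.
batch_matrix_list = [
    torch.stack((trues, preds.detach()), dim=1) for trues, preds in batches
]

# "collect": concatenate all batches, then compute the configured metrics.
concat = torch.cat(batch_matrix_list, dim=0).cpu().numpy()
trues, preds = concat[:, 0], concat[:, 1]
print({
    'auc': roc_auc_score(trues, preds),
    'rmse': np.sqrt(mean_squared_error(trues, preds)),
})
```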
36 changes: 23 additions & 13 deletions evaluator/metrics.py
@@ -6,10 +6,6 @@
    mean_squared_error
)

-"""Function name and function mapper.
-Useful when we have to serialize evaluation metric names
-and call the functions based on deserialized names
-"""

# TopK Metrics #

@@ -114,40 +110,52 @@ def auc(trues, preds):


def mae(trues, preds):
-    """[summary]
+    """Mean absolute error regression loss
-    url:
+    url:https://en.wikipedia.org/wiki/Mean_absolute_error
    """
    return mean_absolute_error(trues, preds)


def rmse(trues, preds):
-    """[summary]
+    """Root mean squared error regression loss
-    url:
+    url:https://en.wikipedia.org/wiki/Root-mean-square_deviation
    """
    return np.sqrt(mean_squared_error(trues, preds))


def log_loss_(trues, preds):
    """Log loss, aka logistic loss or cross-entropy loss
    url:http://wiki.fast.ai/index.php/Log_Loss
    """
    # XXX something wrong
    return log_loss(trues, preds)

# Item based Metrics #


def coverage(n_items, ):
-    pass
+    raise NotImplementedError


def gini_index():
-    pass
+    raise NotImplementedError


def shannon_entropy():
-    pass
+    raise NotImplementedError


def diversity():
-    pass
+    raise NotImplementedError


"""Function name and function mapper.
Useful when we have to serialize evaluation metric names
and call the functions based on deserialized names
"""
metrics_dict = {
    'ndcg': ndcg,
    'hit': hit,
@@ -156,5 +164,7 @@ def diversity():
    'recall': recall,
    'mrr': mrr,
    'rmse': rmse,
-    'mae': mae
+    'mae': mae,
+    'logloss': log_loss_,
+    'auc': auc
}
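With the loss metrics added to the mapper, metric names deserialized from a config can be dispatched straight to the corresponding functions. A small standalone illustration of that lookup pattern (sklearn-backed stand-ins replace the module's own implementations):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Stand-ins for the module-level metric functions.
def mae(trues, preds):
    return mean_absolute_error(trues, preds)

def rmse(trues, preds):
    return np.sqrt(mean_squared_error(trues, preds))

# Name-to-function mapper, as in metrics.py.
metrics_dict = {'mae': mae, 'rmse': rmse}

# Lower-cased metric names taken from a config, dispatched by name.
trues = np.array([3.0, 1.0, 4.0])
preds = np.array([2.5, 1.5, 3.0])
for name in ['rmse', 'mae']:
    print(name, metrics_dict[name](trues, preds))
```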
1 change: 0 additions & 1 deletion evaluator/python/__init__.py

This file was deleted.

65 changes: 0 additions & 65 deletions evaluator/python/python_evaluator.py

This file was deleted.

