diff --git a/data/utils.py b/data/utils.py
index a23eefce1..5de9d4b30 100644
--- a/data/utils.py
+++ b/data/utils.py
@@ -1,5 +1,6 @@
 from .dataloader import *
 from config import EvalSetting
+from utils import ModelType
 
 def data_preparation(config, model, dataset):
     es = EvalSetting(config)
@@ -35,7 +36,7 @@ def data_preparation(config, model, dataset):
     return train_data, test_data, valid_data
 
 def dataloader_construct(name, config, eval_setting, dataset,
-                         dl_type='general', dl_format='pointwise',
+                         dl_type=ModelType.GENERAL, dl_format='pointwise',
                          batch_size=1, shuffle=False):
     if not isinstance(dataset, list):
         dataset = [dataset]
@@ -50,7 +51,7 @@ def dataloader_construct(name, config, eval_setting, dataset,
     print(eval_setting)
     print('batch_size = {}, shuffle = {}\n'.format(batch_size, shuffle))
 
-    if dl_type == 'general':
+    if dl_type == ModelType.GENERAL:
         DataLoader = GeneralDataLoader
     else:
         raise NotImplementedError('dl_type [{}] has not been implemented'.format(dl_type))
diff --git a/evaluator/evaluator.py b/evaluator/evaluator.py
index c1a4d9aec..1719e0d99 100644
--- a/evaluator/evaluator.py
+++ b/evaluator/evaluator.py
@@ -11,8 +11,8 @@
 metric_name = {metric.lower() : metric for metric in ['Hit', 'Recall', 'MRR', 'AUC', 'Precision', 'NDCG']}
 
 # These metrics are typical in topk recommendations
-topk_metric = {'hit', 'recall', 'precision', 'ndcg'}
-other_metric = {'auc', 'mrr'}
+topk_metric = {'hit', 'recall', 'precision', 'ndcg', 'mrr'}
+other_metric = {'auc'}
 
 class AbstractEvaluator(metaclass=abc.ABCMeta):
     """The abstract class of the evaluation module, its subclasses must implement their functions
diff --git a/model/general_recommender/bprmf.py b/model/general_recommender/bprmf.py
index 3905acccc..0614f9118 100644
--- a/model/general_recommender/bprmf.py
+++ b/model/general_recommender/bprmf.py
@@ -15,6 +15,7 @@
 
 from model.abstract_recommender import AbstractRecommender
 from model.loss import BPRLoss
+from utils import ModelType
 
 
 class BPRMF(AbstractRecommender):
@@ -22,7 +23,7 @@ class BPRMF(AbstractRecommender):
 
     def __init__(self, config, dataset):
         super(BPRMF, self).__init__()
-        self.type = 'general'
+        self.type = ModelType.GENERAL
 
         self.USER_ID = config['USER_ID_FIELD']
         self.ITEM_ID = config['ITEM_ID_FIELD']
diff --git a/utils/__init__.py b/utils/__init__.py
index 67235de1a..62a35414e 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,3 +1,4 @@
 from .metrics import *
 from .logger import Logger
 from .utils import get_local_time, ensure_dir
+from .enum_type import *
diff --git a/utils/enum_type.py b/utils/enum_type.py
new file mode 100644
index 000000000..67dcc6a90
--- /dev/null
+++ b/utils/enum_type.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+class ModelType(Enum):
+    GENERAL = 1
+    SEQUENTIAL = 2
+    CONTEXT = 3
+    KNOWLEDGE = 4
+    SOCIAL = 5
diff --git a/utils/metrics.py b/utils/metrics.py
index e5a0f269a..d98e28ef7 100644
--- a/utils/metrics.py
+++ b/utils/metrics.py
@@ -12,7 +12,7 @@ def hit(rank, label, k):
     """
     return int(any(rank[label] <= k))
 
-def mrr(rank, label, k=None):
+def mrr(rank, label, k):
     """The MRR (also known as mean reciprocal rank) is a statistic measure for evaluating any process
     that produces a list of possible responses to a sample of queries, ordered by probability
     of correctness.
@@ -20,8 +20,9 @@ def mrr(rank, label, k=None):
 
     """
     ground_truth_ranks = rank[label]
-    if ground_truth_ranks.all():
-        return (1 / ground_truth_ranks.min())
+    ground_truth_at_k = ground_truth_ranks[ground_truth_ranks <= k]
+    if ground_truth_at_k.shape[0] > 0:
+        return (1 / ground_truth_at_k.min())
     return 0
 
 def recall(rank, label, k):
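Not part of the patch: a small sketch of the string-to-enum switch, using the ModelType introduced in utils/enum_type.py. The build_dataloader helper below is a hypothetical stand-in for dataloader_construct, only meant to show the dispatch pattern the diff adopts.

from enum import Enum


class ModelType(Enum):
    # Mirrors the new utils/enum_type.py
    GENERAL = 1
    SEQUENTIAL = 2
    CONTEXT = 3
    KNOWLEDGE = 4
    SOCIAL = 5


def build_dataloader(dl_type=ModelType.GENERAL):
    # Hypothetical stand-in for dataloader_construct: dispatch on the enum
    # member rather than on a free-form string such as 'general'.
    if dl_type == ModelType.GENERAL:
        return 'GeneralDataLoader'
    raise NotImplementedError('dl_type [{}] has not been implemented'.format(dl_type))


print(build_dataloader(ModelType.GENERAL))  # -> GeneralDataLoader

Compared with bare strings, an unsupported or misspelled value can no longer pass silently as a "valid-looking" type; callers must hand over a ModelType member, and anything unhandled fails loudly in the else branch.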
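Also not part of the patch: a minimal sketch of what the revised mrr computes, assuming rank and label are 1-d NumPy arrays with label acting as a boolean ground-truth mask (that interpretation, and the example_rank/example_label names, are assumptions for illustration, inferred from the boolean indexing and .shape check in utils/metrics.py).

import numpy as np


def mrr(rank, label, k):
    # Same logic as the patched utils/metrics.py: only ground-truth items
    # ranked inside the top-k contribute a reciprocal rank, otherwise 0.
    ground_truth_ranks = rank[label]
    ground_truth_at_k = ground_truth_ranks[ground_truth_ranks <= k]
    if ground_truth_at_k.shape[0] > 0:
        return 1 / ground_truth_at_k.min()
    return 0


# Hypothetical data: ranks of five candidate items (1 = best) and a mask
# marking which of them are ground-truth items.
example_rank = np.array([3, 1, 7, 2, 9])
example_label = np.array([False, False, True, True, False])

print(mrr(example_rank, example_label, k=2))  # best ground-truth rank is 2 -> 0.5
print(mrr(example_rank, example_label, k=1))  # no ground truth in the top-1 -> 0

The pre-patch version ignored k entirely (it returned 1 / ground_truth_ranks.min() whenever all ground-truth ranks were non-zero), so the k=1 call above would still have reported 0.5; cutting off at k is what justifies moving 'mrr' from other_metric into topk_metric in evaluator/evaluator.py.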