Merge pull request #21 from RUCAIBox/master
get latest code
linzihan-backforward authored Jul 23, 2020
2 parents d910a56 + ea951b3 commit e1ea65d
Showing 6 changed files with 20 additions and 8 deletions.
data/utils.py (5 changes: 3 additions & 2 deletions)
@@ -1,5 +1,6 @@
 from .dataloader import *
 from config import EvalSetting
+from utils import ModelType
 
 def data_preparation(config, model, dataset):
     es = EvalSetting(config)
@@ -35,7 +36,7 @@ def data_preparation(config, model, dataset):
     return train_data, test_data, valid_data
 
 def dataloader_construct(name, config, eval_setting, dataset,
-                         dl_type='general', dl_format='pointwise',
+                         dl_type=ModelType.GENERAL, dl_format='pointwise',
                          batch_size=1, shuffle=False):
     if not isinstance(dataset, list):
         dataset = [dataset]
@@ -50,7 +51,7 @@ def dataloader_construct(name, config, eval_setting, dataset,
     print(eval_setting)
     print('batch_size = {}, shuffle = {}\n'.format(batch_size, shuffle))
 
-    if dl_type == 'general':
+    if dl_type == ModelType.GENERAL:
         DataLoader = GeneralDataLoader
     else:
         raise NotImplementedError('dl_type [{}] has not been implemented'.format(dl_type))
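For context, a hypothetical call site for the updated signature; config, es (an EvalSetting) and train_dataset are assumed to exist as in data_preparation, and the batch size is an illustrative value, not taken from this diff:

from utils import ModelType

# Hypothetical usage sketch of dataloader_construct after this change.
train_data = dataloader_construct(
    'train', config, es, train_dataset,
    dl_type=ModelType.GENERAL,   # enum member replaces the old 'general' string
    dl_format='pointwise',
    batch_size=128,              # illustrative value
    shuffle=True,
)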
evaluator/evaluator.py (4 changes: 2 additions & 2 deletions)
@@ -11,8 +11,8 @@
 metric_name = {metric.lower() : metric for metric in ['Hit', 'Recall', 'MRR', 'AUC', 'Precision', 'NDCG']}
 
 # These metrics are typical in topk recommendations
-topk_metric = {'hit', 'recall', 'precision', 'ndcg'}
-other_metric = {'auc', 'mrr'}
+topk_metric = {'hit', 'recall', 'precision', 'ndcg', 'mrr'}
+other_metric = {'auc'}
 
 class AbstractEvaluator(metaclass=abc.ABCMeta):
     """The abstract class of the evaluation module, its subclasses must implement their functions
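Moving 'mrr' into topk_metric matches the utils/metrics.py change below: mrr now takes a required cutoff k like the other top-k metrics. The evaluator's dispatch code is not part of this diff; the following is a hypothetical sketch consistent with the split, assuming a metric_funcs dict mapping lowercase names to the functions in utils/metrics.py:

def evaluate_one(metric, rank, label, k):
    # Hypothetical dispatch: only top-k metrics receive the cutoff k.
    func = metric_funcs[metric]
    if metric in topk_metric:
        return func(rank, label, k)   # hit, recall, precision, ndcg, mrr
    return func(rank, label)          # auc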
model/general_recommender/bprmf.py (3 changes: 2 additions & 1 deletion)
@@ -15,14 +15,15 @@
 
 from model.abstract_recommender import AbstractRecommender
 from model.loss import BPRLoss
+from utils import ModelType
 
 
 class BPRMF(AbstractRecommender):
 
     def __init__(self, config, dataset):
         super(BPRMF, self).__init__()
 
-        self.type = 'general'
+        self.type = ModelType.GENERAL
 
         self.USER_ID = config['USER_ID_FIELD']
         self.ITEM_ID = config['ITEM_ID_FIELD']
utils/__init__.py (1 change: 1 addition & 0 deletions)
@@ -1,3 +1,4 @@
 from .metrics import *
 from .logger import Logger
 from .utils import get_local_time, ensure_dir
+from .enum_type import *
utils/enum_type.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+from enum import Enum
+
+class ModelType(Enum):
+    GENERAL = 1
+    SEQUENTIAL = 2
+    CONTEXT = 3
+    KNOWLEDGE = 4
+    SOCIAL = 5
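The new ModelType enum replaces the ad-hoc 'general' string used in data/utils.py and bprmf.py above. A small self-contained sketch of why an enum is safer than a string tag:

from enum import Enum

class ModelType(Enum):   # trimmed copy of utils/enum_type.py for illustration
    GENERAL = 1
    SEQUENTIAL = 2

assert ModelType.GENERAL != 'general'   # enum members never equal raw strings
try:
    ModelType['GENERL']                 # a typo in an enum name raises here,
except KeyError:                        # whereas a mistyped string tag is only
    print('typo caught at lookup time') # caught (if at all) where it is consumed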
utils/metrics.py (7 changes: 4 additions & 3 deletions)
@@ -12,16 +12,17 @@ def hit(rank, label, k):
     """
     return int(any(rank[label] <= k))
 
-def mrr(rank, label, k=None):
+def mrr(rank, label, k):
     """The MRR (also known as mean reciprocal rank) is a statistic measure for evaluating any process that produces a list
     of possible responses to a sample of queries, ordered by probability of correctness.
     url:https://en.wikipedia.org/wiki/Mean_reciprocal_rank
     """
     ground_truth_ranks = rank[label]
-    if ground_truth_ranks.all():
-        return (1 / ground_truth_ranks.min())
+    ground_truth_at_k = ground_truth_ranks[ground_truth_ranks <= k]
+    if ground_truth_at_k.shape[0] > 0:
+        return (1 / ground_truth_at_k.min())
     return 0
 
 def recall(rank, label, k):
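A small worked example of the fixed mrr, assuming (as in hit above) that rank holds the 1-based rank of each item and label marks the ground-truth items; the concrete values are illustrative:

import numpy as np

rank = np.array([7, 2, 15, 4])                 # hypothetical ranks per item
label = np.array([False, False, True, True])   # ground truth ranked 15 and 4

ground_truth_ranks = rank[label]               # array([15, 4])
for k in (10, 3):
    at_k = ground_truth_ranks[ground_truth_ranks <= k]
    print(k, 1 / at_k.min() if at_k.size else 0)
# k=10 -> 0.25 (best ground-truth rank 4 falls inside the cutoff)
# k=3  -> 0    (the old version ignored k and would still return 0.25)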
