Merge pull request #10 from RUCAIBox/master

Update

hyp1231 authored Jul 15, 2020
2 parents f5cd715 + 49db3e4 commit 45cbb2c
Showing 5 changed files with 204 additions and 40 deletions.
2 changes: 1 addition & 1 deletion config/configurator.py
@@ -103,7 +103,7 @@ def __getitem__(self, item):
        elif item in self.dataset_args:
            return self.dataset_args[item]
        else:
-            raise KeyError("There are no parameter named '%s'" % item)
+            return None

    def __setitem__(self, key, value):
        if not isinstance(key, str):
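With this change a missing parameter becomes a soft failure: config['some_key'] returns None instead of raising KeyError, so callers can probe optional settings inline. A minimal sketch of the calling pattern this enables (parameter names are illustrative):

    # Before: a missing key raised KeyError, forcing try/except around optional settings.
    # After: a missing key yields None, so a caller can fall back inline.
    shrink = config['shrink'] if config['shrink'] is not None else 0.0
    topk = config['topk'] or [10, 20]  # also covers an absent parameter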
70 changes: 37 additions & 33 deletions evaluator/evaluator.py
@@ -1,13 +1,10 @@
import abc
import numpy as np
import pandas as pd
-from collections import namedtuple
import utils

-SCORE_FIELD = 'score'
-
-# 'Precision', 'Hit', 'Recall', 'MAP', 'NDCG', 'MRR'
-metric_name = {metric.lower() : metric for metric in ['Hit', 'Recall', 'MRR']}
+# 'Precision', 'Hit', 'Recall', 'MAP', 'NDCG', 'MRR', 'AUC'
+metric_name = {metric.lower() : metric for metric in ['Hit', 'Recall', 'MRR', 'AUC']}

class AbstractEvaluator(metaclass=abc.ABCMeta):
    """The abstract class of the evaluation module, its subclasses must implement their functions
@@ -42,6 +39,7 @@ def __init__(self, config, logger, metrics, topk):

        self.USER_FIELD = config['USER_ID_FIELD']
        self.ITEM_FIELD = config['ITEM_ID_FIELD']
+        self.neg_ratio = config['test_neg_sample_num']

    def recommend(self, df, k):
        """Recommend the top k items for users
@@ -54,23 +52,30 @@ def recommend(self, df, k):
            (list, list): (user_list, item_list)
        """

-        df['rank'] = df.groupby(self.USER_FIELD)[SCORE_FIELD].rank(method='first', ascending=False)
+        df['rank'] = df.groupby(self.USER_FIELD)['score'].rank(method='first', ascending=False)
        mask = (df['rank'].values > 0) & (df['rank'].values <= k)
        topk_df = df[mask]
        return topk_df[self.USER_FIELD].values.tolist(), topk_df[self.ITEM_FIELD].values.tolist()

-    def metric_info(self, df, metric):
+    def metric_info(self, df, metric, k):
        """Get the result of the metric on the data
        Args:
            df (pandas.core.frame.DataFrame): merged data which contains user id, item id, score, rank
            metric (str): one of the metrics
+            k (int): top k
        Returns:
            float: metric result
        """

-        metric_fuc = getattr(utils, metric)
+        fuc = getattr(utils, metric)
+        if metric == 'auc':
+            metric_fuc = lambda x: fuc(x, self.neg_ratio)
+        elif metric == 'precision':
+            metric_fuc = lambda x: fuc(x, k)
+        else:
+            metric_fuc = fuc
        metric_result = df.groupby(self.USER_FIELD)['rank'].apply(metric_fuc)
        return metric_result

@@ -92,7 +97,7 @@ def evaluate(self, df):
        for k in sorted(self.topk)[::-1]:
            df['rank'] = np.where(df['rank'].values <= k, df['rank'].values, np.full_like(df['rank'].values, -1))
            for metric in self.metrics:
-                metric_result = self.metric_info(df, metric)
+                metric_result = self.metric_info(df, metric, k)
                score = metric_result.sum() / num_users
                key, value = '{}@{}'.format(metric_name[metric], k), score
                metric_dict[key] = value
@@ -107,10 +112,11 @@ def __init__(self, config, logger, metrics):
        self.logger = logger
        self.metrics = metrics
        self.cut_method = ['ceil', 'floor', 'around']
-        self.metric_cols = [SCORE_FIELD, STAR_FIELD]

        self.LABEL_FIELD = config['LABEL_FIELD']
        self.USER_FIELD = config['USER_ID_FIELD']
        self.ITEM_FIELD = config['ITEM_ID_FIELD']
+        self.metric_cols = ['score', self.LABEL_FIELD]

    def cutoff(self, df, col, method):
        """Cut off the col's values by using the method
@@ -180,15 +186,12 @@ def __init__(self, config, logger):
        self.group_view = config['group_view']
        self.eval_metric = config['eval_metric']
        self.topk = config['topk']
-        # TODO: if topk is None, this is no longer top-k recommendation
        self.logger = logger
        self.verbose = True

        self.USER_FIELD = config['USER_ID_FIELD']
        self.ITEM_FIELD = config['ITEM_ID_FIELD']
        self.LABEL_FIELD = config['LABEL_FIELD']
-        # STAR_FIELD = 'review'
-        # GROUP_ID = 'group_id'

        self._check_args()

@@ -279,8 +282,7 @@ def get_grouped_data(self, df):
        group_names = []
        for begin, end in zip(group_view[:-1], group_view[1:]):
            group_names.append('({},{}]'.format(begin, end))
-
-        group_data = df.groupby(GROUP_ID, sort=True)
+        group_data = df.groupby('group_id', sort=True)
        return zip(group_names, group_data)

    def common_evaluate(self, df):
@@ -322,25 +324,24 @@ def _print(self, message):
        if self.verbose:
            self.logger.info(message)

-    def build_result_df(self, rdata, result):
-        # users, items, scores = self.get_result_pairs(result)
-        # result_df = pd.DataFrame({USER_FIELD:users, ITEM_FIELD:items, SCORE_FIELD:scores})
-        # return result_df
+    def build_evaluate_df(self, rdata, result):

-        return pd.DataFrame({
+        df = pd.DataFrame({
            self.USER_FIELD: rdata[self.USER_FIELD],
            self.ITEM_FIELD: rdata[self.ITEM_FIELD],
-            SCORE_FIELD: result
+            'score': result,
+            self.LABEL_FIELD: rdata[self.LABEL_FIELD]
        })
+        return df

-    def build_truth_df(self, rdata):
-        truth_df = pd.DataFrame({
+    def build_recommend_df(self, rdata, result):
+        df = pd.DataFrame({
            self.USER_FIELD: rdata[self.USER_FIELD],
            self.ITEM_FIELD: rdata[self.ITEM_FIELD],
-            self.LABEL_FIELD: rdata[self.LABEL_FIELD]
+            'score': result
        })
-        return truth_df[truth_df[self.LABEL_FIELD] >= 1]
+        return df

    def recommend(self, rdata, result, k):
        """Recommend the top k items for users
@@ -352,8 +353,8 @@ def recommend(self, rdata, result, k):
            (list, list): (user_list, item_list)
        """

-        result_df = self.build_result_df(rdata, result)
-        return self.evaluator.recommend(result_df, k)
+        df = self.build_recommend_df(rdata, result)
+        return self.evaluator.recommend(df, k)

    def evaluate(self, result, rdata):
        """Generate metrics results on the dataset
@@ -365,12 +366,15 @@ def evaluate(self, result, rdata):
            dict: a dict
        """

-        result_df = self.build_result_df(rdata, result)
-        truth_df = self.build_truth_df(rdata)
+        df = self.build_evaluate_df(rdata, result)
        if self.topk is not None:
-            result_df['rank'] = result_df.groupby(self.USER_FIELD)[SCORE_FIELD].rank(method='first', ascending=False)
-            df = pd.merge(truth_df, result_df, on=[self.USER_FIELD, self.ITEM_FIELD], how='left')
-            df['rank'].fillna(-1, inplace=True)
+            df['rank'] = df.groupby(self.USER_FIELD)['score'].rank(method='first', ascending=False)
+            mask = df[self.LABEL_FIELD] > 0
+            df = df[mask].copy()
+            # TODO: for CTR-style models the input may be 0/1-labeled data, so neg_ratio can differ across users, and computing AUC needs that data
+            # neg_df = df[~mask].copy()
+            # neg_df['neg_ratio'] = neg_df.groupby(self.USER_FIELD)[self.LABEL_FIELD].count()
+            # truth_df = truth_df.merge(neg_df[[self.USER_FIELD, self.]])
        if self.group_view is not None:
            return self.group_evaluate(df)
        return self.common_evaluate(df)
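To see the top-k evaluation flow above end to end — rank scores per user, keep the positive-label rows, then apply a rank-based metric per user — here is a minimal self-contained sketch; the toy data and the inline hit/MRR lambdas are illustrative stand-ins for the functions in utils:

    import pandas as pd

    # Toy data: two users, three scored items each; label marks the ground-truth item.
    df = pd.DataFrame({
        'user_id': [1, 1, 1, 2, 2, 2],
        'item_id': [10, 11, 12, 10, 11, 12],
        'score':   [0.9, 0.5, 0.1, 0.2, 0.8, 0.4],
        'label':   [0, 1, 0, 1, 0, 0],
    })

    # Rank items per user by score (rank 1 = best), as in evaluate() above.
    df['rank'] = df.groupby('user_id')['score'].rank(method='first', ascending=False)

    # Keep only positive-label rows; their ranks drive the rank-based metrics.
    pos = df[df['label'] > 0].copy()

    k = 2
    pos['rank'] = pos['rank'].where(pos['rank'] <= k, -1)  # ranks beyond k count as misses

    hit = pos.groupby('user_id')['rank'].apply(lambda r: int((r > 0).any()))
    mrr = pos.groupby('user_id')['rank'].apply(lambda r: (1.0 / r[r > 0]).max() if (r > 0).any() else 0.0)
    print(hit.mean(), mrr.mean())  # Hit@2 = 0.5, MRR@2 = 0.25 on this toy data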
153 changes: 153 additions & 0 deletions model/general_recommender/KNN.py
@@ -0,0 +1,153 @@
import numpy as np
from model.abstract_recommender import AbstractRecommender
import scipy.sparse as sp
import torch


class ComputeSimilarity:

    def __init__(self, dataMatrix, topK=100, shrink=0, normalize=True):
        """Compute the cosine similarity on the columns of dataMatrix.

        If it is computed on URM=|users|x|items|, pass the URM as is.
        If it is computed on ICM=|items|x|features|, pass the ICM transposed.

        :param dataMatrix: sparse matrix whose columns are compared
        :param topK: keep the topK most similar columns per column
        :param shrink: shrinkage term added to the similarity denominator
        :param normalize: if True, divide the dot product by the product of the norms

        Asymmetric Cosine as described in:
        Aiolli, F. (2013, October). Efficient top-n recommendation for very large scale binary rated datasets.
        In Proceedings of the 7th ACM conference on Recommender systems (pp. 273-280). ACM.
        """
        super(ComputeSimilarity, self).__init__()

        self.shrink = shrink
        self.normalize = normalize

        self.n_rows, self.n_columns = dataMatrix.shape
        self.TopK = min(topK, self.n_columns)

        self.dataMatrix = dataMatrix.copy()

    def compute_similarity(self, block_size=100):
        """Compute the similarity for the given dataset.

        :param block_size: process the columns in blocks of this size when computing the cosine distance
        :return: sparse matrix W of shape (self.n_columns, self.n_columns)
        """
        values = []
        rows = []
        cols = []

        self.dataMatrix = self.dataMatrix.astype(np.float32)

        # Compute sum of squared values to be used in normalization
        sumOfSquared = np.array(self.dataMatrix.power(2).sum(axis=0)).ravel()
        sumOfSquared = np.sqrt(sumOfSquared)

        end_col_local = self.n_columns
        start_col_block = 0

        # Compute all similarities for each item using vectorization
        while start_col_block < end_col_local:

            end_col_block = min(start_col_block + block_size, end_col_local)
            this_block_size = end_col_block - start_col_block

            # All data points for a given item
            item_data = self.dataMatrix[:, start_col_block:end_col_block]
            item_data = item_data.toarray().squeeze()

            if item_data.ndim == 1:
                item_data = np.atleast_2d(item_data)

            # Compute item similarities
            this_block_weights = self.dataMatrix.T.dot(item_data)

            for col_index_in_block in range(this_block_size):

                if this_block_size == 1:
                    this_column_weights = this_block_weights
                else:
                    this_column_weights = this_block_weights[:, col_index_in_block]

                columnIndex = col_index_in_block + start_col_block
                this_column_weights[columnIndex] = 0.0

                # Apply normalization and shrinkage, ensure denominator != 0
                if self.normalize:
                    denominator = sumOfSquared[columnIndex] * sumOfSquared + self.shrink + 1e-6
                    this_column_weights = np.multiply(this_column_weights, 1 / denominator)

                elif self.shrink != 0:
                    this_column_weights = this_column_weights / self.shrink

                # Sort indices and select TopK.
                # Sorting is done in three steps. Faster than plain np.argsort for a higher number of items:
                # - Partition the data to extract the set of relevant items
                # - Sort only the relevant items
                # - Get the original item index
                relevant_items_partition = (-this_column_weights).argpartition(self.TopK - 1)[0:self.TopK]
                relevant_items_partition_sorting = np.argsort(-this_column_weights[relevant_items_partition])
                top_k_idx = relevant_items_partition[relevant_items_partition_sorting]

                # Incrementally build the sparse matrix, do not add zeros
                notZerosMask = this_column_weights[top_k_idx] != 0.0
                numNotZeros = np.sum(notZerosMask)

                values.extend(this_column_weights[top_k_idx][notZerosMask])
                rows.extend(top_k_idx[notZerosMask])
                cols.extend(np.ones(numNotZeros) * columnIndex)

            start_col_block += block_size

        # End while on columns

        W_sparse = sp.csr_matrix((values, (rows, cols)),
                                 shape=(self.n_columns, self.n_columns),
                                 dtype=np.float32)

        return W_sparse.tocsc()


class ItemKNN(AbstractRecommender):
    def __init__(self, config, dataset):
        self.device = config['device']
        self.USER_ID = config['USER_ID_FIELD']
        self.ITEM_ID = config['ITEM_ID_FIELD']
        self.n_users = len(dataset.field2id_token[self.USER_ID])
        self.n_items = len(dataset.field2id_token[self.ITEM_ID])

        self.interaction_matrix = dataset.train_matrix.tocsr().astype(np.float32)
        shape = self.interaction_matrix.shape
        assert self.n_users == shape[0] and self.n_items == shape[1]
        self.k = config['k']
        self.shrink = config['shrink'] if 'shrink' in config else 0.0
        self.w = ComputeSimilarity(self.interaction_matrix, topK=self.k, shrink=self.shrink).compute_similarity()
        self.pred_mat = self.interaction_matrix.dot(self.w).tolil()

    def forward(self, user, item):
        pass

    def calculate_loss(self, interaction):
        pass

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        user = user.cpu().numpy().astype(int)
        item = item.cpu().numpy().astype(int)
        result = []

        for index in range(len(user)):
            uid = user[index]
            iid = item[index]
            score = self.pred_mat[uid, iid]
            result.append(score)
        result = torch.from_numpy(np.array(result)).to(self.device)
        return result
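For a feel of what ComputeSimilarity produces, here is a minimal sketch on a toy interaction matrix; the values are made up, and only the ComputeSimilarity class from the file above is assumed:

    import numpy as np
    import scipy.sparse as sp

    # Toy 4-user x 3-item interaction matrix (1 = user interacted with item).
    urm = sp.csr_matrix(np.array([
        [1, 1, 0],
        [1, 0, 1],
        [0, 1, 0],
        [1, 1, 1],
    ], dtype=np.float32))

    # Item-item cosine similarities, keeping the top-2 neighbors per item.
    w = ComputeSimilarity(urm, topK=2, shrink=0).compute_similarity()

    # Score all items for every user the way ItemKNN does: a user's score for
    # an item is the similarity-weighted sum of their past interactions.
    pred = urm.dot(w).toarray()
    print(np.round(pred, 3))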
2 changes: 1 addition & 1 deletion properties/overall.config
@@ -29,7 +29,7 @@ checkpoint_dir='saved'
seed=2020

# evaluating
eval_metric=["Recall", "Hit", "MRR"]
eval_metric=["Recall", "Hit", "MRR", "AUC"]
topk=[10, 20]
valid_batch_size=2048
test_batch_size=2048
17 changes: 12 additions & 5 deletions utils/metrics.py
@@ -18,11 +18,18 @@ def recall(data):
def ndcg(data):
    pass

-def precision(data):
-    pass
-
-def auc(data):
-    pass
+def precision(data, k):
+    return (data > 0).sum() / k
+
+def auc(data, neg_ratio):
+    if neg_ratio == 0:
+        return 1
+    pos_num = (data > 0).sum()
+    if pos_num == 0:
+        return 0
+    pos_ranksum = (pos_num * (neg_ratio + 1) + 1 - data[data > 0]).sum()
+    neg_num = pos_num * neg_ratio
+    return (pos_ranksum - pos_num * (pos_num + 1) / 2) / (pos_num * neg_num)
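This auc is the Mann-Whitney rank formula: data holds the top-down ranks of a user's positive items among pos_num * (1 + neg_ratio) candidates, pos_num * (neg_ratio + 1) + 1 - data converts them to bottom-up ranks, and subtracting the minimum possible rank-sum pos_num * (pos_num + 1) / 2 then dividing by pos_num * neg_num gives the fraction of (positive, negative) pairs ranked correctly. A quick sanity check, assuming the function as defined above (toy ranks only):

    import numpy as np

    # One positive ranked 1st among 1 positive + 4 negatives: perfect ranking.
    print(auc(np.array([1]), neg_ratio=4))   # 1.0

    # One positive ranked last (5th of 5): worst ranking.
    print(auc(np.array([5]), neg_ratio=4))   # 0.0

    # Two positives ranked 1st and 3rd among 2 positives + 2 negatives:
    # bottom-up ranks 4 and 2, rank-sum 6, AUC = (6 - 3) / (2 * 2) = 0.75.
    print(auc(np.array([1, 3]), neg_ratio=1))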

## Loss based Metrics ##

