
Merge pull request #2 from RUCAIBox/master
Update to master
hyp1231 authored Jun 30, 2020
2 parents c0fc9e6 + 8a1273a commit 4bac813
Showing 13 changed files with 461 additions and 110 deletions.
88 changes: 86 additions & 2 deletions data/data.py
@@ -87,8 +87,8 @@ def split_by_ratio(self, train_ratio, test_ratio, valid_ratio=0,
train_inter[k] = self.interaction[k][:train_cnt]
test_inter[k] = self.interaction[k][train_cnt : train_cnt+test_cnt]
if valid_ratio > 0:
valid_inter = self.interaction[k][train_cnt+test_cnt:]
valid_inter[k] = self.interaction[k][train_cnt+test_cnt:]

if valid_ratio > 0:
return Data(config=self.config, interaction=train_inter, batch_size=train_batch_size, sampler=self.sampler), \
Data(config=self.config, interaction=test_inter, batch_size=test_batch_size, sampler=self.sampler), \
@@ -123,3 +123,87 @@ def remove_lower_value_by_key(self, key, min_remain_value=0):

return Data(config=self.config, interaction=new_inter, batch_size=self.batch_size, sampler=new_sampler)

def neg_sample_1by1(self):
new_inter = {
'user_id': [],
'pos_item_id': [],
'neg_item_id': []
}
for i in range(self.__len__()):
uid = self.interaction['user_id'][i].item()
new_inter['user_id'].append(uid)
new_inter['pos_item_id'].append(self.interaction['item_id'][i].item())
new_inter['neg_item_id'].append(self.sampler.sample_by_user_id(uid)[0])
for k in new_inter:
new_inter[k] = torch.LongTensor(new_inter[k])
return Data(
config=self.config,
interaction=new_inter,
batch_size=self.batch_size,
sampler=self.sampler
)
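
The new `neg_sample_1by1` pairs every observed (user, item) interaction with one negative item drawn from the sampler, yielding BPR-style triples. A minimal standalone sketch of the same idea, using plain Python dicts in place of the repo's `Data` and sampler objects (`neg_sample_1by1_sketch`, `user_pos`, and `n_items` are illustrative names, not part of this commit):

import random

def neg_sample_1by1_sketch(user_pos, n_items):
    # user_pos maps user id -> list of positively interacted item ids
    triples = []
    for uid, pos_items in user_pos.items():
        seen = set(pos_items)
        for pos in pos_items:
            # resample until we draw an item this user has not interacted with;
            # assumes n_items > len(seen)
            neg = random.randrange(n_items)
            while neg in seen:
                neg = random.randrange(n_items)
            triples.append((uid, pos, neg))
    return triples

print(neg_sample_1by1_sketch({0: [1, 3], 1: [2]}, n_items=5))
# e.g. [(0, 1, 4), (0, 3, 0), (1, 2, 0)]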

# def neg_sample_to(self, num):
# new_inter = {
# 'user_id': [],
# 'item_list': [],
# 'label': []
# }

# uid2itemlist = {}
# for i in range(self.__len__()):
# uid = self.interaction['user_id'][i].item()
# iid = self.interaction['item_id'][i].item()
# if uid not in uid2itemlist:
# uid2itemlist[uid] = []
# uid2itemlist[uid].append(iid)
# for uid in uid2itemlist:
# pos_num = len(uid2itemlist[uid])
# if pos_num >= num:
# uid2itemlist[uid] = uid2itemlist[uid][:num-1]
# pos_num = num - 1
# neg_item_id = self.sampler.sample_by_user_id(uid, num - pos_num)
# uid2itemlist[uid] += neg_item_id
# label = [1] * pos_num + [0] * (num - pos_num)
# new_inter['user_id'].append(uid)
# new_inter['item_list'].append(uid2itemlist[uid])
# new_inter['label'].append(label)

# for k in new_inter:
# new_inter[k] = torch.LongTensor(new_inter[k])

# return Data(config=self.config, interaction=new_inter, batch_size=self.batch_size, sampler=self.sampler)

def neg_sample_to(self, num):
new_inter = {
'user_id': [],
'item_id': [],
'label': []
}

uid2itemlist = {}
for i in range(self.__len__()):
uid = self.interaction['user_id'][i].item()
iid = self.interaction['item_id'][i].item()
if uid not in uid2itemlist:
uid2itemlist[uid] = []
uid2itemlist[uid].append(iid)
for uid in uid2itemlist:
pos_num = len(uid2itemlist[uid])
if pos_num >= num:
uid2itemlist[uid] = uid2itemlist[uid][:num-1]
pos_num = num - 1
neg_item_id = self.sampler.sample_by_user_id(uid, num - pos_num)
for iid in uid2itemlist[uid]:
new_inter['user_id'].append(uid)
new_inter['item_id'].append(iid)
new_inter['label'].append(1)
for iid in neg_item_id:
new_inter['user_id'].append(uid)
new_inter['item_id'].append(iid)
new_inter['label'].append(0)

for k in new_inter:
new_inter[k] = torch.LongTensor(new_inter[k])

return Data(config=self.config, interaction=new_inter, batch_size=self.batch_size, sampler=self.sampler)
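
`neg_sample_to` instead tops each user's item list up to a fixed length `num`, truncating to at most `num - 1` positives so at least one negative is always present, and labels positives 1 and negatives 0. A standalone sketch under the same assumptions as above (plain dicts, hypothetical names):

import random

def neg_sample_to_sketch(user_pos, n_items, num):
    rows = []  # (user_id, item_id, label) rows, as in the method above
    for uid, pos_items in user_pos.items():
        seen = set(pos_items)
        if len(pos_items) >= num:
            pos_items = pos_items[:num - 1]  # leave room for one negative
        negs = []
        while len(negs) < num - len(pos_items):
            cand = random.randrange(n_items)  # assumes enough unseen items exist
            if cand not in seen and cand not in negs:
                negs.append(cand)
        rows += [(uid, iid, 1) for iid in pos_items]
        rows += [(uid, iid, 0) for iid in negs]
    return rows

print(neg_sample_to_sketch({0: [1, 3]}, n_items=6, num=4))
# e.g. [(0, 1, 1), (0, 3, 1), (0, 5, 0), (0, 2, 0)]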
35 changes: 31 additions & 4 deletions data/dataset.py
@@ -85,10 +85,37 @@ def preprocessing(self, workflow=None):
Preprocessing of the dataset
'''
cur = self.dataset
for func in workflow:
if func == 'split':
cur = cur.split(self.config['process.ratio'])
return cur
train_data = test_data = valid_data = None
for func in workflow['preprocessing']:
if func == 'remove_lower_value_by_key':
cur = cur.remove_lower_value_by_key(
key=self.config['process.remove_lower_value_by_key.key'],
min_remain_value=self.config['process.remove_lower_value_by_key.min_remain_value']
)
elif func == 'split_by_ratio':
train_data, test_data, valid_data = cur.split_by_ratio(
train_ratio=self.config['process.split_by_ratio.train_ratio'],
test_ratio=self.config['process.split_by_ratio.test_ratio'],
valid_ratio=self.config['process.split_by_ratio.valid_ratio'],
train_batch_size=self.config['train_batch_size'],
test_batch_size=self.config['test_batch_size'],
valid_batch_size=self.config['valid_batch_size']
)
break

for func in workflow['train']:
if func == 'neg_sample_1by1':
train_data = train_data.neg_sample_1by1()

for func in workflow['test']:
if func == 'neg_sample_to':
test_data = test_data.neg_sample_to(num=self.config['process.neg_sample_to.num'])

for func in workflow['valid']:
if func == 'neg_sample_to':
valid_data = valid_data.neg_sample_to(num=self.config['process.neg_sample_to.num'])

return train_data, test_data, valid_data
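
`preprocessing` now walks a phased workflow dict rather than a flat function list, and returns the three splits directly. A plausible workflow matching the branches above (the dict an actual config file produces may differ; `dataset` is an assumed instance):

workflow = {
    'preprocessing': ['remove_lower_value_by_key', 'split_by_ratio'],
    'train': ['neg_sample_1by1'],  # one sampled negative per training interaction
    'test': ['neg_sample_to'],     # pad per-user test lists to a fixed length
    'valid': ['neg_sample_to'],
}
train_data, test_data, valid_data = dataset.preprocessing(workflow)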

class UIRTDataset(AbstractDataset):
def __init__(self, config):
120 changes: 94 additions & 26 deletions evaluator/evaluator.py
@@ -10,9 +10,12 @@
# here put the import lib

import pandas as pd
import numpy as np
from utils import Logger
import utils



# 'Precision', 'Hit', 'Recall', 'MAP', 'NDCG', 'MRR'
metric_name = {metric.lower() : metric for metric in ['Hit', 'Recall', 'MRR']}
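
`metric_name` maps the lower-cased metric names used in configs back to their canonical display form; evaluated, the comprehension above is simply:

assert metric_name == {'hit': 'Hit', 'recall': 'Recall', 'mrr': 'MRR'}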

@@ -22,6 +22,18 @@ class BaseEvaluator(object):
def __init__(self):
pass

def merge_data(self, result, test_data):
"""Merge `result` and `test_data` to support evaluation
Args:
result (tuple): a tuple of (users, items); both `users` and `items` are lists of the same length.
test_data ([type]): TODO [description]
Returns:
(pandas.core.frame.DataFrame): the merged data frame
"""
raise NotImplementedError

def metrics_info(self, merged_data):
"""Get all metrics information for the merged data.
@@ -30,6 +30,18 @@ def metrics_info(self, merged_data):
"""
raise NotImplementedError

def collect(self, result_list, batch_size_list):
"""Aggregate the results of all batches
Args:
result_list (list): list which contains the metric dict of each batch
batch_size_list (list): the batch sizes corresponding to `result_list`
Raises:
NotImplementedError: [description]
"""
raise NotImplementedError

def evaluate(self, result, test_data):
"""Evaluate the result generated by the trained model and get metrics information on the specified data
@@ -38,16 +65,16 @@ def evaluate(self, result, test_data):
test_data ([type]) : TODO
Returns:
A string which consists of all information about the result on the test_data
(dict) : A dict which consists of all information about the result on the test_data
"""
return NotImplementedError
raise NotImplementedError

# TODO this should be the main focus for speedups
class UnionEvaluator(BaseEvaluator):
"""`UnionEvaluator` evaluates results on ungrouped data.
"""
def __init__(self, eval_metric, topk, workers):
def __init__(self, logger, eval_metric, topk, workers):
"""[summary]
Args:
@@ -56,6 +83,7 @@ def __init__(self, eval_metric, topk, workers):
workers ([type]): [description]
"""
super(UnionEvaluator, self).__init__()
self.logger = logger
self.topk = topk
self.eval_metric = eval_metric
self.workers = workers
@@ -72,8 +100,19 @@ def get_ground_truth(self, users, test_data):
items (list): users' ground truth.
"""
# TODO hook this up
users, items = test_data
return users, items
users = set(users)
users_lst = []
items_lst = []

for interaction in test_data:
for i in range(interaction['user_id'].shape[0]):
uid = interaction['user_id'][i].item()
iid = interaction['item_id'][i].item()
label = interaction['label'][i].item()
if label == 0 or uid not in users: continue
users_lst.append(uid)
items_lst.append(iid)
return users_lst, items_lst
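
The rewritten `get_ground_truth` walks the labelled batches produced by `neg_sample_to` and keeps only the positive pairs belonging to the requested users. A toy illustration with plain lists standing in for tensor batches (`batches` and `wanted` are made-up names):

batches = [{'user_id': [0, 0, 1], 'item_id': [5, 7, 5], 'label': [1, 0, 1]}]
wanted = {0, 1}
users_lst, items_lst = [], []
for b in batches:
    for u, i, l in zip(b['user_id'], b['item_id'], b['label']):
        if l == 1 and u in wanted:  # drop sampled negatives and unrequested users
            users_lst.append(u)
            items_lst.append(i)
print(users_lst, items_lst)  # [0, 1] [5, 5]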

def get_result_pairs(self, result):
"""[summary]
@@ -119,13 +158,34 @@ def metric_info(self, merged_data):
dict: A dict mapping each 'Metric@k' name to its score
"""

metric_info = []
metric_dict = {}
for k in self.topk:
for method in self.eval_metric:
eval_fuc = getattr(utils, method)
score = eval_fuc(merged_data, k)
metric_info.append('{:>5}@{} : {:5f}'.format(metric_name[method], k, score))
return '\t'.join(metric_info)
key, value = '{}@{}'.format(metric_name[method], k), score
metric_dict[key] = value
return metric_dict
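
`metric_info` now returns a dict keyed by 'Metric@k' instead of a tab-joined string. With `eval_metric = ['hit', 'recall']` and `topk = [5, 7]`, the result would have the shape below (scores copied from the docstring example later in this file, purely illustrative):

metric_dict = {
    'Hit@5': 0.484848, 'Recall@5': 0.162734,
    'Hit@7': 0.727273, 'Recall@7': 0.236760,
}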

def collect(self, result_list, batch_size_list):

tmp_result_list = []
keys = list(result_list[0].keys())
for result in result_list:
tmp_result_list.append(list(result.values()))


result_matrix = np.array(tmp_result_list)
batch_size_matrix = np.array(batch_size_list).reshape(-1, 1)
assert result_matrix.shape[0] == batch_size_matrix.shape[0]

weighted_matrix = result_matrix * batch_size_matrix

metric_list = (np.sum(weighted_matrix, axis=0) / np.sum(batch_size_matrix)).tolist()
metric_dict = {}
for method, score in zip(keys, metric_list):
metric_dict[method] = score
return metric_dict
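
`collect` therefore reduces the per-batch metric dicts to a single dict by micro-averaging, weighting each batch's scores by its batch size. The same computation in isolation (toy numbers, not real output):

import numpy as np

result_list = [{'Hit@5': 0.50}, {'Hit@5': 0.25}]
batch_size_list = [3, 1]
scores = np.array([list(r.values()) for r in result_list])  # shape (n_batches, n_metrics)
weights = np.array(batch_size_list).reshape(-1, 1)
print((scores * weights).sum(axis=0) / weights.sum())  # [0.4375] = (0.5*3 + 0.25*1) / 4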

def evaluate(self, result, test_data):
"""Evaluate `model`.
@@ -139,8 +199,8 @@ def evaluate(self, result, test_data):
Hit@5 : 0.484848 Recall@5 : 0.162734 Hit@7 : 0.727273 Recall@7 : 0.236760`.
"""
merged_data = self.merge_data(result, test_data)
info_str = self.metric_info(merged_data)
return info_str
metric_dict = self.metric_info(merged_data)
return metric_dict

class GroupEvaluator(UnionEvaluator):
"""`GroupEvaluator` evaluates results in user groups.
@@ -152,7 +212,7 @@ class GroupEvaluator(UnionEvaluator):
For example, if `group_view = [10, 30, 50, 100]`, users will be split into
five groups: `(0, 10]`, `(10, 30]`, `(30, 50]`, `(50, 100]`, `(100, -]`.
"""
def __init__(self, group_view, eval_metric, topk, workers):
def __init__(self, logger, group_view, eval_metric, topk, workers):
"""[summary]
Args:
@@ -162,7 +222,7 @@ def __init__(self, group_view, eval_metric, topk, workers):
performance.
workers (int): `workers` controls the number of threads.
"""
super(GroupEvaluator, self).__init__(eval_metric, topk, workers)
super(GroupEvaluator, self).__init__(logger, eval_metric, topk, workers)
self.group_view = group_view
self.group_names = self.get_group_names()
# print(self.group_names)
@@ -206,13 +266,16 @@ def evaluate_groups(self, groups):
Returns:
str: A string consisting of all results.
"""

### XXX converting this to a string has a bug, so it is temporarily not done here
info_list = []
for index, group in groups:
info_str = self.metric_info(group)
info_list.append('{:<5}\t{}'.format(self.group_names[index], info_str))
return '\n'.join(info_list)

def collect(self, result_list, batch_size_list):
raise NotImplementedError

def evaluate(self, result, test_data):
"""Evaluate `model`.
@@ -229,8 +292,8 @@
merged_data = self.merge_data(result, test_data)
grouped_data = self.get_grouped_data(merged_data, test_data)
groups = self.groupby(grouped_data, 'group_id')
info_str = self.evaluate_groups(groups)
return info_str
metric_dict = self.evaluate_groups(groups)
return metric_dict

class Evaluator(BaseEvaluator):
"""`Evaluator` is the interface to evaluate models.
@@ -250,7 +313,7 @@ class Evaluator(BaseEvaluator):
in `training data`. It is configurable via the argument `group_view`.
"""
def __init__(self, config):
def __init__(self, config, logger):
"""Initialize the evaluator by the global configuration file.
Args:
@@ -263,18 +326,19 @@ def __init__(self, config):
"""
super(Evaluator, self).__init__()

self.group_view = config['group_view']
self.eval_metric = config['metric']
self.topk = config['topk']
self.workers = config['workers'] # TODO multiprocessing, but it may be tricky on Windows; it seems it has to run under __main__

self.group_view = config['eval.group_view']
self.eval_metric = config['eval.metric']
self.topk = config['eval.topk']
self.workers = 0 # TODO multiprocessing, but it may be tricky on Windows; it seems it has to run under __main__
self.logger = logger

# XXX where should this kind of type checking live? Checked once in the config module, or scattered across the individual modules?
self._check_args()

if self.group_view is not None:
self.evaluator = GroupEvaluator(self.group_view, self.eval_metric, self.topk, self.workers)
self.evaluator = GroupEvaluator(self.logger, self.group_view, self.eval_metric, self.topk, self.workers)
else:
self.evaluator = UnionEvaluator(self.eval_metric, self.topk, self.workers)
self.evaluator = UnionEvaluator(self.logger, self.eval_metric, self.topk, self.workers)
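
The constructor now reads namespaced `eval.*` keys and dispatches on `eval.group_view`. A minimal config sketch matching the lookups above (a real config object need not be a plain dict; values are illustrative):

config = {
    'eval.group_view': None,  # or e.g. [10, 30, 50, 100] to get a GroupEvaluator
    'eval.metric': ['hit', 'recall'],
    'eval.topk': [5, 7],
}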

def _check_args(self):

@@ -323,8 +387,11 @@ def evaluate(self, result, test_data):
Returns:
dict: A dict consisting of all results
"""
info_str = self.evaluator.evaluate(result, test_data)
print(info_str)
metric_dict = self.evaluator.evaluate(result, test_data)
return metric_dict

def collect(self, result_list, batch_size_list):
return self.evaluator.collect(result_list, batch_size_list)

def __str__(self):
return 'The evaluator will evaluate test_data on {} at {}'.format(', '.join(self.eval_metric), ', '.join(map(str, self.topk)))
@@ -333,8 +400,9 @@ def __repr__(self):
return self.__str__()

def __enter__(self):
print('Evaluate Start...')
self.logger.info('Evaluate Start...')
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.logger.info('Evaluate End...')
pass
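
With `__enter__` and `__exit__` now logging through the injected logger, the evaluator can be driven as a context manager. A usage sketch (`config`, `logger`, `result`, and `test_data` are assumptions):

with Evaluator(config, logger) as evaluator:
    metric_dict = evaluator.evaluate(result, test_data)
    # e.g. {'Hit@5': 0.484848, 'Recall@5': 0.162734, ...}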