Skip to content

Commit

Permalink
Merge pull request #19 from RUCAIBox/master
Browse files Browse the repository at this point in the history
Merge
  • Loading branch information
chenyushuo authored Aug 3, 2020
2 parents f3f9d0b + 273f511 commit 37baf5f
Show file tree
Hide file tree
Showing 16 changed files with 488 additions and 520 deletions.
12 changes: 6 additions & 6 deletions config/eval_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def set_splitting(self, strategy='none', **kwargs):
self.split_args = {'strategy': strategy}
self.split_args.update(kwargs)

def leave_one_out(self, leave_one_num=1):
    """Select the leave-one-out ('loo') splitting strategy.

    Args:
        leave_one_num: Forwarded unchanged to ``set_splitting`` as the
            ``leave_one_num`` keyword. Defaults to 1.

    Raises:
        ValueError: If ``self.group_field`` is ``None``; the check is made
            before any splitting argument is recorded.
    """
    if self.group_field is None:
        raise ValueError('Leave one out request grouped dataset, please set group field.')
    self.set_splitting(strategy='loo', leave_one_num=leave_one_num)

def split_by_ratio(self, ratios):
if not isinstance(ratios, list):
Expand Down Expand Up @@ -184,15 +184,15 @@ def TO_RS(self, ratios=[0.8, 0.1, 0.1]):
self.temporal_ordering()
self.split_by_ratio(ratios)

def RO_LS(self, leave_one_num=1):
    """Configure grouping by user, random ordering and leave-one-out splitting.

    Calls ``group_by_user``, ``random_ordering`` and ``leave_one_out`` in
    that order; grouping comes first because ``leave_one_out`` is invoked
    last and receives the requested hold-out count.

    Args:
        leave_one_num: Forwarded to ``leave_one_out``. Defaults to 1.
    """
    self.group_by_user()
    self.random_ordering()
    self.leave_one_out(leave_one_num=leave_one_num)

def TO_LS(self, leave_one_num=1):
    """Configure grouping by user, temporal ordering and leave-one-out splitting.

    Calls ``group_by_user``, ``temporal_ordering`` and ``leave_one_out`` in
    that order.

    Args:
        leave_one_num: Forwarded to ``leave_one_out``. Defaults to 1.
    """
    self.group_by_user()
    self.temporal_ordering()
    self.leave_one_out(leave_one_num=leave_one_num)

def uni100(self, real_time=False):
self.neg_sample_by(100, real_time=real_time)
Expand Down
22 changes: 21 additions & 1 deletion data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,26 @@ def split_by_ratio(self, ratios, group_by=None):
next_ds = [self.copy(_) for _ in next_df]
return next_ds

def leave_one_out(self, group_by, leave_one_num=1):
    """Split ``self.inter_feat`` by holding out the trailing rows of each group.

    Rows are grouped by ``group_by``. Within each group, in the order the
    group's index labels are listed, the last rows are held out one per
    split and everything before them goes to the first split. Every group
    keeps at least one row in the first split, so a group with fewer than
    ``leave_one_num + 1`` rows only contributes to the earliest held-out
    splits.

    Args:
        group_by: Grouping key passed to ``DataFrame.groupby(by=...)``.
        leave_one_num: Number of held-out splits to produce. Defaults to 1.

    Returns:
        A list of ``leave_one_num + 1`` objects, each produced by calling
        ``self.copy`` on one re-indexed slice of ``self.inter_feat``; the
        first corresponds to the retained rows.

    Raises:
        ValueError: If ``group_by`` is ``None`` or ``leave_one_num`` is
            negative.
    """
    if group_by is None:
        raise ValueError('leave one out strategy require a group field')
    if leave_one_num < 0:
        # Without this check an empty split list is built and the first
        # access below fails with an unhelpful IndexError.
        raise ValueError('leave_one_num must be non-negative')

    split_labels = [[] for _ in range(leave_one_num + 1)]
    for labels in self.inter_feat.groupby(by=group_by).groups.values():
        labels = list(labels)
        # Never hold out so many rows that the first split gets nothing.
        held_out = min(leave_one_num, len(labels) - 1)
        keep = len(labels) - held_out
        split_labels[0].extend(labels[:keep])
        # zip stops at the shorter side, i.e. after ``held_out`` rows.
        for bucket, label in zip(split_labels[1:], labels[keep:]):
            bucket.append(label)

    frames = [self.inter_feat.loc[index].reset_index(drop=True) for index in split_labels]
    return [self.copy(frame) for frame in frames]

def shuffle(self):
    """Replace ``self.inter_feat`` with a randomly reordered copy of itself.

    All rows are sampled (``frac=1``) and the index is then rebuilt from
    zero, so the old row labels are discarded. Returns ``None``.
    """
    self.inter_feat = self.inter_feat.sample(frac=1).reset_index(drop=True)

Expand All @@ -379,7 +399,7 @@ def build(self, eval_setting):
elif split_args['strategy'] == 'by_value':
raise NotImplementedError()
elif split_args['strategy'] == 'loo':
raise NotImplementedError()
datasets = self.leave_one_out(group_by=group_field, leave_one_num=split_args['leave_one_num'])
else:
datasets = self

Expand Down
9 changes: 5 additions & 4 deletions data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
def data_preparation(config, logger, model, dataset, save=False):
es_str = [_.strip() for _ in config['eval_setting'].split(',')]
es = EvalSetting(config)
if 'RS' in es_str[0]:
getattr(es, es_str[0])(ratios=config['split_ratio'])
else:
getattr(es, es_str[0])()

kargs = {}
if 'RS' in es_str[0]: kargs['ratios'] = config['split_ratio']
if 'LS' in es_str[0]: kargs['leave_one_num'] = config['leave_one_num']
getattr(es, es_str[0])(**kargs)

builded_datasets = dataset.build(es)
train_dataset, valid_dataset, test_dataset = builded_datasets
Expand Down
2 changes: 1 addition & 1 deletion evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .evaluator import Evaluator
from .evaluator import TopKEvaluator, LossEvaluator
1 change: 1 addition & 0 deletions evaluator/cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .cpp_evaluator import BaseLossEvaluator, BaseCTREvaluator, BaseTopKEvaluator
34 changes: 34 additions & 0 deletions evaluator/cpp/cpp_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
class CTREvaluator(object):
    """Placeholder for a CTR evaluator; no metric is implemented yet."""

    def __init__(self):
        pass

    def eval_metric(self):
        """Do nothing and return ``None`` (not implemented yet).

        Open question left by the author: implement this in Python?
        """


class TOPKEvaluator(object):
    """Placeholder for a top-k evaluator; no metric is implemented yet."""

    def __init__(self):
        pass

    def eval_metric(self):
        """Do nothing and return ``None`` (not implemented yet).

        Open question left by the author: implement this in Cython?
        """


class LossEvaluator(object):
    """Placeholder for a loss evaluator; no metric is implemented yet."""

    def __init__(self):
        pass

    def eval_metric(self):
        """Do nothing and return ``None`` (not implemented yet).

        Open question left by the author: implement this in Cython?
        """


# The package ``__init__`` imports the ``Base*`` names from this module, which
# were not defined here; these aliases keep that import from failing while
# leaving the original class names available.
BaseCTREvaluator = CTREvaluator
BaseTopKEvaluator = TOPKEvaluator
BaseLossEvaluator = LossEvaluator
Loading

0 comments on commit 37baf5f

Please sign in to comment.