Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Bug fix, code format && comments format #499

Merged
merged 6 commits into from
Nov 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions recbole/data/dataloader/general_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _neg_sample_by_point_wise_sampling(self, uid_field, iid_field, neg_iids, int
def get_pos_len_list(self):
"""
Returns:
np.ndarray or list: Number of positive item for each user in a training/evaluating epoch.
np.ndarray: Number of positive item for each user in a training/evaluating epoch.
"""
return self.uid2items_num

Expand Down Expand Up @@ -289,6 +289,6 @@ def _neg_sampling(self, uid2index, show_progress=False):
def get_pos_len_list(self):
"""
Returns:
np.ndarray or list: Number of positive item for each user in a training/evaluating epoch.
np.ndarray: Number of positive item for each user in a training/evaluating epoch.
"""
return self.uid2items_num
2 changes: 1 addition & 1 deletion recbole/data/dataloader/neg_sample_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _neg_sampling(self, inter_feat):
def get_pos_len_list(self):
"""
Returns:
np.ndarray or list: Number of positive item for each user in a training/evaluating epoch.
np.ndarray: Number of positive item for each user in a training/evaluating epoch.
"""
raise NotImplementedError('Method [get_pos_len_list] should be implemented.')

Expand Down
2 changes: 1 addition & 1 deletion recbole/data/dataloader/sequential_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _neg_sample_by_point_wise_sampling(self, data, neg_iids):
def get_pos_len_list(self):
"""
Returns:
np.ndarray or list: Number of positive item for each user in a training/evaluating epoch.
np.ndarray: Number of positive item for each user in a training/evaluating epoch.
"""
return np.ones(self.pr_end, dtype=np.int64)

Expand Down
10 changes: 5 additions & 5 deletions recbole/evaluator/abstract_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AbstractEvaluator(object):
the evaluation of the model. It is called by :class:`Trainer`.

Note:
If you want to inherit this class and implement your own evalautor class,
If you want to inherit this class and implement your own evaluator class,
you must implement the following functions.

Args:
Expand All @@ -29,18 +29,18 @@ def _check_args(self):
"""check the correctness of the settings"""
raise NotImplementedError

def collect(self):
def collect(self, *args):
"""get the intermediate results for each batch, it is called at the end of each batch"""
raise NotImplementedError

def evaluate(self):
def evaluate(self, *args):
"""calculate the metrics of all batches, it is called at the end of each epoch"""
raise NotImplementedError

def metrics_info(self):
def metrics_info(self, *args):
"""get metrics result"""
raise NotImplementedError

def _calculate_metrics(self):
def _calculate_metrics(self, *args):
""" to calculate the metrics"""
raise NotImplementedError
12 changes: 8 additions & 4 deletions recbole/evaluator/loss_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class LossEvaluator(AbstractEvaluator):

Note:
The metrics used do not calculate group-based metrics which considers the metrics scores averaged across users.
They are also not limited to k. Instead, they calculate the scores on the entire prediction results regardless the users.
They are also not limited to k. Instead, they calculate the scores on the entire prediction results regardless
of the users.

"""
def __init__(self, config):
Expand All @@ -46,7 +47,7 @@ def collect(self, interaction, pred_scores):
pred_scores (tensor): the tensor of model output with a size of `(N, )`

Returns:
tensor : a batch of socres with a size of `(N, 2)`
tensor : a batch of scores with a size of `(N, 2)`

"""
true_scores = interaction[self.label_field].to(pred_scores.device)
Expand Down Expand Up @@ -113,5 +114,8 @@ def _calculate_metrics(self, trues, preds):
return self.metrics_info(trues, preds)

def __str__(self):
mesg = 'The Loss Evaluator Info:\n' + '\tMetrics:[' + ', '.join([loss_metrics[metric.lower()] for metric in self.metrics]) + ']'
return mesg
msg = 'The Loss Evaluator Info:\n' + \
'\tMetrics:[' + \
', '.join([loss_metrics[metric.lower()] for metric in self.metrics]) + \
']'
return msg
23 changes: 14 additions & 9 deletions recbole/evaluator/topk_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def collect(self, interaction, scores_tensor, full=False):
scores_matrix = scores_tensor.view(len(user_len_list), -1)
else:
scores_list = torch.split(scores_tensor, user_len_list, dim=0)
scores_matrix = pad_sequence(scores_list, batch_first=True, padding_value=-np.inf) # nusers x items
scores_matrix = pad_sequence(scores_list, batch_first=True, padding_value=-np.inf) # n_users x items

# get topk
_, topk_index = torch.topk(scores_matrix, max(self.topk), dim=-1) # nusers x k
_, topk_index = torch.topk(scores_matrix, max(self.topk), dim=-1) # n_users x k

return topk_index

Expand Down Expand Up @@ -105,17 +105,18 @@ def _check_args(self):
self.topk = [self.topk]
for topk in self.topk:
if topk <= 0:
raise ValueError('topk must be a positive integer or a list of positive integers, but get `{}`'.format(topk))
raise ValueError('topk must be a positive integer or a list of positive integers, '
'but get `{}`'.format(topk))
else:
raise TypeError('The topk must be a integer, list')

def metrics_info(self, pos_idx, pos_len):
"""get metrics result

Args:
pos_idx (np.ndarray): the bool index of all users' topk items that indicating the postive items are
pos_idx (np.ndarray): the bool index of all users' topk items that indicating the positive items are
topk items or not
pos_len (list): the length of all users' postivite items
pos_len (np.ndarray): the length of all users' positive items

Returns:
list: a list of matrix which record the results from `1` to `max(topk)`
Expand All @@ -132,7 +133,7 @@ def _calculate_metrics(self, pos_len_list, topk_index):
"""integrate the results of each batch and evaluate the topk metrics by users

Args:
pos_len_list (list): a list of users' positive items
pos_len_list (np.ndarray): a list of users' positive items
topk_index (np.ndarray): a matrix which contains the index of the topk items for users

Returns:
Expand All @@ -146,6 +147,10 @@ def _calculate_metrics(self, pos_len_list, topk_index):
return result

def __str__(self):
mesg = 'The TopK Evaluator Info:\n' + '\tMetrics:[' + ', '.join([topk_metrics[metric.lower()] for metric in self.metrics]) \
+ '], TopK:[' + ', '.join(map(str, self.topk)) +']'
return mesg
msg = 'The TopK Evaluator Info:\n' + \
'\tMetrics:[' + \
', '.join([topk_metrics[metric.lower()] for metric in self.metrics]) + \
'], TopK:[' + \
', '.join(map(str, self.topk)) + \
']'
return msg
18 changes: 8 additions & 10 deletions recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def __init__(self, config, model):

self.item_tensor = None
self.tot_item_num = None
self.iid_field = config['ITEM_ID_FIELD']

def _build_optimizer(self):
r"""Init the Optimizer
Expand Down Expand Up @@ -209,13 +208,11 @@ def _check_nan(self, loss):
raise ValueError('Training loss is nan')

def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
train_loss_output = "epoch %d training [time: %.2fs, " % (epoch_idx, e_time - s_time)
train_loss_output = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
if isinstance(losses, tuple):
for idx, loss in enumerate(losses):
train_loss_output += 'train_loss%d: %.4f, ' % (idx + 1, loss)
train_loss_output = train_loss_output[:-2]
train_loss_output = ', '.join('train_loss%d: %.4f' % (idx + 1, loss) for idx, loss in enumerate(losses))
else:
train_loss_output += "train loss: %.4f" % losses
train_loss_output += 'train loss: %.4f' % losses
return train_loss_output + ']'

def fit(self, train_data, valid_data=None, verbose=True, saved=True):
Expand All @@ -231,8 +228,9 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
Returns:
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
"""
if hasattr(self.model, 'train_preparation'):
self.model.train_preparation(train_data=train_data, valid_data=valid_data)
if saved and self.start_epoch >= self.epochs:
self._save_checkpoint(-1)

for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
Expand Down Expand Up @@ -616,8 +614,8 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
Returns:
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
"""
if hasattr(self.model, 'train_preparation'):
self.model.train_preparation(train_data=train_data, valid_data=valid_data)
if saved and self.start_epoch >= self.epochs:
self._save_checkpoint(-1)

if self.model.train_stage == 'pretrain':
return self.pretrain(train_data, verbose)
Expand Down
6 changes: 5 additions & 1 deletion recbole/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ def get_model(model_name):
]

model_file_name = model_name.lower()
model_module = None
for submodule in model_submodule:
module_path = '.'.join(['...model', submodule, model_file_name])
module_path = '.'.join(['recbole.model', submodule, model_file_name])
if importlib.util.find_spec(module_path, __name__):
model_module = importlib.import_module(module_path, __name__)
break

if model_module is None:
raise ValueError('`model_name` [{}] is not the name of an existing model.'.format(model_name))
model_class = getattr(model_module, model_name)
return model_class

Expand Down