Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename EvaluatorType & FIX: bugs in model #937

Merged
merged 4 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from logging import getLogger

from recbole.evaluator import group_metrics, individual_metrics
from recbole.evaluator import rank_metrics, value_metrics
from recbole.utils import get_model, Enum, EvaluatorType, ModelType, InputType, \
general_arguments, training_arguments, evaluation_arguments, dataset_arguments, set_color

Expand Down Expand Up @@ -291,13 +291,13 @@ def _set_default_parameters(self):

eval_type = None
for metric in self.final_config_dict['metrics']:
if metric.lower() in individual_metrics:
if metric.lower() in value_metrics:
if eval_type is not None and eval_type == EvaluatorType.RANKING:
raise RuntimeError('Ranking metrics and other metrics can not be used at the same time.')
else:
eval_type = EvaluatorType.INDIVIDUAL
if metric.lower() in group_metrics:
if eval_type is not None and eval_type == EvaluatorType.INDIVIDUAL:
eval_type = EvaluatorType.VALUE
if metric.lower() in rank_metrics:
if eval_type is not None and eval_type == EvaluatorType.VALUE:
raise RuntimeError('Ranking metrics and other metrics can not be used at the same time.')
else:
eval_type = EvaluatorType.RANKING
Expand Down Expand Up @@ -394,7 +394,11 @@ def __setitem__(self, key, value):
self.final_config_dict[key] = value

def __getattr__(self, item):
return self.__getitem__(item)
if 'final_config_dict' not in self.__dict__:
raise AttributeError(f"'Config' object has no attribute 'final_config_dict'")
if item in self.final_config_dict:
return self.final_config_dict[item]
raise AttributeError(f"'Config' object has no attribute '{item}'")

def __getitem__(self, item):
if item in self.final_config_dict:
Expand Down
4 changes: 2 additions & 2 deletions recbole/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from recbole.evaluator.metrics import metrics_dict
from recbole.evaluator.collector import DataStruct
from recbole.evaluator.register import loss_metrics
from recbole.evaluator.register import value_metrics


class Evaluator(object):
Expand Down Expand Up @@ -44,7 +44,7 @@ def evaluate(self, dataobject: DataStruct):

def _check_args(self):
# Check Loss
if set(self.metrics) & set(loss_metrics):
if set(self.metrics) & set(value_metrics):
is_full = 'full' in self.config['eval_args']['mode']
if is_full:
raise NotImplementedError('Full sort evaluation do not match the metrics!')
15 changes: 4 additions & 11 deletions recbole/evaluator/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,11 @@
'rmse': ['rec.score', 'data.label'],
'mae': ['rec.score', 'data.label'],
'logloss': ['rec.score', 'data.label']}
# These metrics are typical in top-k recommendations
topk_metrics = {metric.lower(): metric for metric in ['Hit', 'Recall', 'MRR', 'Precision', 'NDCG', 'MAP',
# These metrics are typical in ranking-based recommendations
rank_metrics = {metric.lower(): metric for metric in ['Hit', 'Recall', 'MRR', 'Precision', 'NDCG', 'MAP', 'GAUC'
'ItemCoverage', 'AveragePopularity', 'ShannonEntropy', 'GiniIndex']}
# These metrics are typical in loss recommendations
loss_metrics = {metric.lower(): metric for metric in ['AUC', 'RMSE', 'MAE', 'LOGLOSS']}
# For GAUC
rank_metrics = {metric.lower(): metric for metric in ['GAUC']}

# group-based metrics
group_metrics = ChainMap(topk_metrics, rank_metrics)
# not group-based metrics
individual_metrics = ChainMap(loss_metrics)
# These metrics are typical in value-based recommendations
value_metrics = {metric.lower(): metric for metric in ['AUC', 'RMSE', 'MAE', 'LOGLOSS']}


class Register(object):
Expand Down
11 changes: 5 additions & 6 deletions recbole/model/context_aware_recommender/xdeepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, config, dataset):
)

# Create a convolutional layer for each CIN layer
self.conv1d_list = []
self.conv1d_list = nn.ModuleList()
self.field_nums = [self.num_feature_field]
for i, layer_size in enumerate(self.cin_layer_size):
conv1d = nn.Conv1d(self.field_nums[-1] * self.field_nums[0], layer_size, 1).to(self.device)
Expand All @@ -74,14 +74,13 @@ def __init__(self, config, dataset):
else:
self.final_len = sum(self.cin_layer_size[:-1]) // 2 + self.cin_layer_size[-1]

self.cin_linear = nn.Linear(self.final_len, 1, bias=False)
self.cin_linear = nn.Linear(self.final_len, 1)
self.sigmoid = nn.Sigmoid()
self.loss = nn.BCELoss()
self.apply(self._init_weights)
self.other_parameter_name = ['conv1d_list']

def _init_weights(self, module):
if isinstance(module, nn.Embedding):
if isinstance(module, nn.Embedding) or isinstance(module, nn.Conv1d):
xavier_normal_(module.weight.data)
elif isinstance(module, nn.Linear):
xavier_normal_(module.weight.data)
Expand Down Expand Up @@ -114,7 +113,7 @@ def calculate_reg_loss(self):
l2_reg += self.reg_loss(conv1d.named_parameters())
return l2_reg

def compressed_interaction_network(self, input_features, activation='identity'):
def compressed_interaction_network(self, input_features, activation='ReLU'):
r"""For k-th CIN layer, the output :math:`X_k` is calculated via

.. math::
Expand All @@ -138,7 +137,7 @@ def compressed_interaction_network(self, input_features, activation='identity'):
hidden_nn_layers = [input_features]
final_result = []
for i, layer_size in enumerate(self.cin_layer_size):
z_i = torch.einsum('bmd,bhd->bhmd', hidden_nn_layers[0], hidden_nn_layers[-1])
z_i = torch.einsum('bhd,bmd->bhmd', hidden_nn_layers[-1], hidden_nn_layers[0])
z_i = z_i.view(batch_size, self.field_nums[0] * self.field_nums[i], embedding_size)
z_i = self.conv1d_list[i](z_i)

Expand Down
2 changes: 1 addition & 1 deletion recbole/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def embed_input_fields(self, user_idx, item_idx):
feature = user_item_feat[type][field_name][user_item_idx[type]]
float_fields.append(feature if len(feature.shape) == (2 + (type == 'item')) else feature.unsqueeze(-1))
if len(float_fields) > 0:
float_fields = torch.cat(float_fields, dim=1) # [batch_size, max_item_length, num_float_field]
float_fields = torch.cat(float_fields, dim=-1) # [batch_size, max_item_length, num_float_field]
else:
float_fields = None
# [batch_size, max_item_length, num_float_field]
Expand Down
2 changes: 1 addition & 1 deletion recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _neg_sample_batch_eval(self, batched_data):
else:
origin_scores = self._spilt_predict(interaction, batch_size)

if self.config['eval_type'] == EvaluatorType.INDIVIDUAL:
if self.config['eval_type'] == EvaluatorType.VALUE:
return interaction, origin_scores, positive_u, positive_i
elif self.config['eval_type'] == EvaluatorType.RANKING:
col_idx = interaction[self.config['ITEM_ID_FIELD']]
Expand Down
6 changes: 3 additions & 3 deletions recbole/utils/enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ class KGDataLoaderState(Enum):
class EvaluatorType(Enum):
"""Type for evaluation metrics.

- ``RANKING``: Ranking metrics like NDCG, Recall, etc.
- ``INDIVIDUAL``: Individual metrics like AUC, etc.
- ``RANKING``: Ranking-based metrics like NDCG, Recall, etc.
- ``VALUE``: Value-based metrics like AUC, etc.
"""

RANKING = 1
INDIVIDUAL = 2
VALUE = 2


class InputType(Enum):
Expand Down