Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.2.x #17

Merged
merged 21 commits into from
Jan 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8107582
Merge pull request #8 from RUCAIBox/0.2.x
linzihan-backforward Dec 25, 2020
80779d3
Merge pull request #9 from RUCAIBox/0.2.x
linzihan-backforward Dec 27, 2020
b5e57a2
Merge pull request #625 from 2017pxy/master
chenyushuo Dec 28, 2020
3dd1f20
FEA: Raise ValueError when inter_feat have label_field and neg-sampli…
chenyushuo Dec 29, 2020
8e2f434
Merge branch '0.2.x' of github.com:RUCAIBox/RecBole into 0.2.x
chenyushuo Dec 29, 2020
6be1c87
FIX: Bug fix in test_evaluation_setting.py
chenyushuo Dec 29, 2020
290e785
FIX: bug in NeuMF
linzihan-backforward Dec 29, 2020
ce38830
Merge pull request #628 from chenyushuo/0.2.x
2017pxy Dec 29, 2020
cbd40d6
Merge pull request #629 from linzihan-backforward/master
2017pxy Dec 29, 2020
541f6d9
REFACTOR: refactor in test_model
chenyushuo Dec 29, 2020
b85005b
Merge pull request #1 from RUCAIBox/0.2.x
hyp1231 Dec 29, 2020
4f7fc9b
FIX: rm confusing error report from create_dataset
hyp1231 Dec 29, 2020
84c83f0
FIX: refactor customized dataset selection
hyp1231 Dec 29, 2020
66bd447
FIX: DCN crash when running on CPU
guijiql Dec 29, 2020
1e9f1c1
Merge pull request #631 from hyp1231/0.2.x
chenyushuo Dec 30, 2020
78370d8
Merge pull request #633 from guijiql/0.2.x
chenyushuo Dec 30, 2020
37aadba
Merge branch '0.2.x' of github.com:RUCAIBox/RecBole into 0.2.x
chenyushuo Dec 30, 2020
dfd565a
FIX: device error in BERT4Rec (from master branch)
chenyushuo Dec 30, 2020
023e12a
REFACTOR: Move tests in test_model_manual.py to test_model_auto.py
chenyushuo Dec 30, 2020
597fcf0
REFACTOR: Move test_s3rec in test_model_auto.py to test_model_manual.py
chenyushuo Dec 30, 2020
d3bcead
Merge pull request #630 from chenyushuo/0.2.x
chenyushuo Dec 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def _load_internal_config_dict(self, model, model_class, dataset):
config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
if config_dict is not None:
self.internal_config_dict.update(config_dict)
elif model in ['GRU4RecKG','KSR']:
with open(sequential_embedding_model_init, 'r', encoding='utf-8') as f:
elif model in ['GRU4RecKG', 'KSR']:
with open(sequential_embedding_model_init, 'r', encoding='utf-8') as f:
config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
if config_dict is not None:
self.internal_config_dict.update(config_dict)
Expand Down
13 changes: 10 additions & 3 deletions recbole/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def create_dataset(config):
Returns:
Dataset: Constructed dataset.
"""
try:
return getattr(importlib.import_module('recbole.data.dataset'), config['model'] + 'Dataset')(config)
except AttributeError:
dataset_module = importlib.import_module('recbole.data.dataset')
if hasattr(dataset_module, config['model'] + 'Dataset'):
return getattr(dataset_module, config['model'] + 'Dataset')(config)
else:
model_type = config['MODEL_TYPE']
if model_type == ModelType.SEQUENTIAL:
from .dataset import SequentialDataset
Expand Down Expand Up @@ -97,6 +98,9 @@ def data_preparation(config, dataset, save=False):

kwargs = {}
if config['training_neg_sample_num']:
if dataset.label_field in dataset.inter_feat:
raise ValueError(f'`training_neg_sample_num` should be 0 '
f'if inter_feat have label_field [{dataset.label_field}].')
train_distribution = config['training_neg_sample_distribution'] or 'uniform'
es.neg_sample_by(by=config['training_neg_sample_num'], distribution=train_distribution)
if model_type != ModelType.SEQUENTIAL:
Expand All @@ -121,6 +125,9 @@ def data_preparation(config, dataset, save=False):

kwargs = {}
if len(es_str) > 1 and getattr(es, es_str[1], None):
if dataset.label_field in dataset.inter_feat:
raise ValueError(f'It can not validate with `{es_str[1]}` '
f'when inter_feat have label_field [{dataset.label_field}].')
getattr(es, es_str[1])()
if 'sampler' not in locals():
sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
Expand Down
11 changes: 4 additions & 7 deletions recbole/model/context_aware_recommender/dcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,10 @@ def __init__(self, config, dataset):

# define layers and loss
# init weight and bias of each cross layer
self.cross_layer_parameter = [nn.Parameter(torch.empty(self.num_feature_field * self.embedding_size,
device=self.device))
for _ in range(self.cross_layer_num * 2)]
self.cross_layer_w = nn.ParameterList(
self.cross_layer_parameter[:self.cross_layer_num])
self.cross_layer_b = nn.ParameterList(
self.cross_layer_parameter[self.cross_layer_num:])
self.cross_layer_w = nn.ParameterList(nn.Parameter(torch.randn(self.num_feature_field * self.embedding_size)
.to(self.device)) for _ in range(self.cross_layer_num))
self.cross_layer_b = nn.ParameterList(nn.Parameter(torch.zeros(self.num_feature_field * self.embedding_size)
.to(self.device)) for _ in range(self.cross_layer_num))

# size of mlp hidden layer
size_list = [self.embedding_size * self.num_feature_field] + self.mlp_hidden_size
Expand Down
4 changes: 2 additions & 2 deletions recbole/model/general_recommender/neumf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def __init__(self, config, dataset):
self.item_mf_embedding = nn.Embedding(self.n_items, self.mf_embedding_size)
self.user_mlp_embedding = nn.Embedding(self.n_users, self.mlp_embedding_size)
self.item_mlp_embedding = nn.Embedding(self.n_items, self.mlp_embedding_size)
self.mlp_layers = MLPLayers([2 * self.mlp_embedding_size] + self.mlp_hidden_size)
self.mlp_layers = MLPLayers([2 * self.mlp_embedding_size] + self.mlp_hidden_size, self.dropout_prob)
self.mlp_layers.logger = None # remove logger to use torch.save()
if self.mf_train and self.mlp_train:
self.predict_layer = nn.Linear(self.mf_embedding_size + self.mlp_hidden_size[-1], 1, self.dropout_prob)
self.predict_layer = nn.Linear(self.mf_embedding_size + self.mlp_hidden_size[-1], 1)
elif self.mf_train:
self.predict_layer = nn.Linear(self.mf_embedding_size, 1)
elif self.mlp_train:
Expand Down
14 changes: 7 additions & 7 deletions recbole/model/sequential_recommender/bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __init__(self, config, dataset):
self.mask_item_length = int(self.mask_ratio * self.max_seq_length)

# define layers and loss
self.item_embedding = nn.Embedding(self.n_items+1, self.hidden_size, padding_idx=0) # mask token add 1
self.position_embedding = nn.Embedding(self.max_seq_length+1, self.hidden_size) # add mask_token at the last
self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0) # mask token add 1
self.position_embedding = nn.Embedding(self.max_seq_length + 1, self.hidden_size) # add mask_token at the last
self.trm_encoder = TransformerEncoder(n_layers=self.n_layers, n_heads=self.n_heads,
hidden_size=self.hidden_size, inner_size=self.inner_size,
hidden_dropout_prob=self.hidden_dropout_prob,
Expand Down Expand Up @@ -99,8 +99,8 @@ def _neg_sample(self, item_set):

def _padding_sequence(self, sequence, max_length):
pad_len = max_length - len(sequence)
sequence = [0]*pad_len + sequence
sequence = sequence[-max_length:] # truncate according to the max_length
sequence = [0] * pad_len + sequence
sequence = sequence[-max_length:] # truncate according to the max_length
return sequence

def reconstruct_train_data(self, item_seq):
Expand Down Expand Up @@ -193,7 +193,7 @@ def multi_hot_embed(self, masked_index, max_length):
multi_hot_embed: [[0 1 0 0 0], [0 0 0 1 0]]
"""
masked_index = masked_index.view(-1)
multi_hot = torch.zeros(masked_index.size(0), max_length).cuda()
multi_hot = torch.zeros(masked_index.size(0), max_length, device=masked_index.device)
multi_hot[torch.arange(masked_index.size(0)), masked_index] = 1
return multi_hot

Expand Down Expand Up @@ -237,7 +237,7 @@ def predict(self, interaction):
test_item = interaction[self.ITEM_ID]
item_seq = self.reconstruct_test_data(item_seq, item_seq_len)
seq_output = self.forward(item_seq)
seq_output = self.gather_indexes(seq_output, item_seq_len-1) # [B H]
seq_output = self.gather_indexes(seq_output, item_seq_len - 1) # [B H]
test_item_emb = self.item_embedding(test_item)
scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B]
return scores
Expand All @@ -247,7 +247,7 @@ def full_sort_predict(self, interaction):
item_seq_len = interaction[self.ITEM_SEQ_LEN]
item_seq = self.reconstruct_test_data(item_seq, item_seq_len)
seq_output = self.forward(item_seq)
seq_output = self.gather_indexes(seq_output, item_seq_len-1) # [B H]
seq_output = self.gather_indexes(seq_output, item_seq_len - 1) # [B H]
test_items_emb = self.item_embedding.weight[:self.n_items] # delete masked token
scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, item_num]
return scores
7 changes: 5 additions & 2 deletions tests/evaluation_setting/test_evaluation_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_rols_full(self):
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
'''

def test_tols_full(self):
config_dict = {
'eval_setting': 'TO_LS,full',
Expand Down Expand Up @@ -72,6 +73,7 @@ def test_tols_full(self):
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
'''

def test_tors_full(self):
config_dict = {
'eval_setting': 'TO_RS,full',
Expand Down Expand Up @@ -182,7 +184,7 @@ def test_tors_uni100(self):
'model': 'BPR',
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
config_file_list=config_file_list, saved=False)
# config_dict = {
# 'eval_setting': 'TO_RS,uni100',
# 'model': 'NeuMF',
Expand All @@ -208,10 +210,11 @@ class TestContextRecommender(unittest.TestCase):
def test_tors(self):
config_dict = {
'eval_setting': 'TO_RS',
'threshold': {'rating': 4},
'model': 'FM',
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
config_file_list=config_file_list, saved=False)
# config_dict = {
# 'eval_setting': 'TO_RS',
# 'model': 'DeepFM',
Expand Down
2 changes: 1 addition & 1 deletion tests/model/test_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ENTITY_ID_FIELD: entity_id

# Selectively Loading
load_col:
inter: [user_id, item_id, rating, timestamp, label]
inter: [user_id, item_id, rating, timestamp]
user: [user_id, age, gender, occupation]
item: [item_id, movie_title, release_year, class]
link: [item_id, entity_id]
Expand Down
Loading