Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: Add ENMF model #643

Merged
merged 9 commits into from
Jan 10, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion recbole/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def get_data_loader(name, config, eval_setting):
register_table = {
'DIN': _get_DIN_data_loader,
"MultiDAE": _get_AE_data_loader,
"MultiVAE": _get_AE_data_loader
"MultiVAE": _get_AE_data_loader,
"ENMF": _get_AE_data_loader
}

if config['model'] in register_table:
Expand Down
1 change: 1 addition & 0 deletions recbole/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from recbole.model.general_recommender.line import LINE
from recbole.model.general_recommender.multidae import MultiDAE
from recbole.model.general_recommender.multivae import MultiVAE
from recbole.model.general_recommender.enmf import ENMF
from recbole.model.general_recommender.nais import NAIS
from recbole.model.general_recommender.neumf import NeuMF
from recbole.model.general_recommender.ngcf import NGCF
Expand Down
123 changes: 123 additions & 0 deletions recbole/model/general_recommender/enmf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
# @Time : 2020/12/31
# @Author : Zihan Lin
# @Email : zhlin@ruc.edu.cn

r"""
MultiDAE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiDAE -> ENMF

################################################
Reference:
Chong Chen et al. "Efficient Neural Matrix Factorization without Sampling for Recommendation." in TOIS 2020.

Reference code:
https://github.com/chenchongthu/ENMF
"""

import torch
import torch.nn as nn
from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.abstract_recommender import GeneralRecommender


class ENMF(GeneralRecommender):
    r"""ENMF: Efficient Neural Matrix Factorization without Sampling.

    A user-item pair is scored as ``H_i^T (p_u * q_i)`` and the model is
    trained on the whole interaction matrix with a closed-form non-sampling
    loss, so no negative sampling is needed during training.
    """

    # whole-data (non-sampling) training consumes pointwise interactions
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(ENMF, self).__init__(config, dataset)

        self.embedding_dim = config['embedding_dim']
        self.dropout_prob = config['dropout_prob']
        self.reg_weight = config['reg_weight']
        # weight applied to the "negative" (all-pairs) part of the loss
        self.negative_weight = config['negative_weight']

        # All users' history interactions, padded with item id 0 up to the
        # maximum interaction count over users; shape [n_users, max_len].
        self.history_item_matrix, _, self.history_lens = dataset.history_item_matrix()
        self.history_item_matrix = self.history_item_matrix.to(self.device)

        # padding_idx=0 pins the pad embedding to the zero vector, so padded
        # history slots contribute nothing to scores or to the loss below
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_dim, padding_idx=0)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_dim, padding_idx=0)
        self.H_i = nn.Linear(self.embedding_dim, 1, bias=False)
        self.dropout = nn.Dropout(self.dropout_prob)

        self.apply(xavier_normal_initialization)

    def reg_loss(self):
        """Calculate the L2 regularization over the embedding tables.

        Returns:
            torch.Tensor: scalar reg loss, scaled by ``reg_weight``
        """
        l2_reg = self.user_embedding.weight.norm(2) + self.item_embedding.weight.norm(2)
        loss_l2 = self.reg_weight * l2_reg

        return loss_l2

    def forward(self, user):
        """Score every history item of each given user.

        Args:
            user (torch.Tensor): user ids, shape [B]

        Returns:
            torch.Tensor: scores of the users' history items, shape [B, max_len]
        """
        user_embedding = self.user_embedding(user)  # shape: [B, embedding_dim]
        user_embedding = self.dropout(user_embedding)  # shape: [B, embedding_dim]

        user_inter = self.history_item_matrix[user]  # shape: [B, max_len]
        item_embedding = self.item_embedding(user_inter)  # shape: [B, max_len, embedding_dim]
        score = torch.mul(user_embedding.unsqueeze(1), item_embedding)  # shape: [B, max_len, embedding_dim]
        score = self.H_i(score)  # shape: [B, max_len, 1]
        # squeeze only the trailing dim: a bare squeeze() would also drop the
        # batch dim when B == 1
        score = score.squeeze(-1)  # shape: [B, max_len]

        return score

    def calculate_loss(self, interaction):
        """Non-sampling loss of ENMF (constant terms dropped).

        The negative part over ALL user-item pairs is computed in closed form
        from the Gram matrices of the embedding tables, avoiding sampling.

        Args:
            interaction (Interaction): batch carrying user ids

        Returns:
            torch.Tensor: scalar training loss
        """
        user = interaction[self.USER_ID]

        # scores of observed (positive) interactions; padded slots score 0
        pos_score = self.forward(user)

        # Gram matrix of item embeddings, shape: [embedding_dim, embedding_dim]
        item_sum = torch.bmm(self.item_embedding.weight.unsqueeze(2), self.item_embedding.weight.unsqueeze(1)).sum(dim=0)

        # Gram matrix of user embeddings, shape: [embedding_dim, embedding_dim]
        user_sum = torch.bmm(self.user_embedding.weight.unsqueeze(2), self.user_embedding.weight.unsqueeze(1)).sum(dim=0)

        # shape: [embedding_dim, embedding_dim]
        H_sum = torch.matmul(self.H_i.weight.t(), self.H_i.weight)

        # closed-form sum of squared scores over all user-item pairs
        t = torch.sum(item_sum * user_sum * H_sum)

        loss = self.negative_weight * t

        # positive part; the constant "+1" per positive pair is omitted
        loss = loss + torch.sum((1 - self.negative_weight) * torch.square(pos_score) - 2 * pos_score)

        loss = loss + self.reg_loss()

        return loss

    def predict(self, interaction):
        """Score the given (user, item) pairs.

        Returns:
            torch.Tensor: scores, shape [B]
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        u_e = self.user_embedding(user)
        i_e = self.item_embedding(item)

        score = torch.mul(u_e, i_e)  # shape: [B, embedding_dim]
        score = self.H_i(score)  # shape: [B, 1]

        return score.squeeze(1)

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch.

        Returns:
            torch.Tensor: flattened scores, shape [B * n_items]
        """
        user = interaction[self.USER_ID]

        u_e = self.user_embedding(user)  # shape: [B, embedding_dim]

        all_i_e = self.item_embedding.weight  # shape: [n_items, embedding_dim]

        score = torch.mul(u_e.unsqueeze(1), all_i_e.unsqueeze(0))  # shape: [B, n_items, embedding_dim]

        score = self.H_i(score).squeeze(2)  # shape: [B, n_items]

        return score.view(-1)





5 changes: 5 additions & 0 deletions recbole/properties/model/ENMF.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
embedding_dim: 64
dropout_prob: 0.7
reg_weight: 0.0
negative_weight: 0.5
training_neg_sample_num: 0
6 changes: 6 additions & 0 deletions tests/model/test_model_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ def test_MultiVAE(self):
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)

def test_enmf(self):
    # end-to-end smoke test for the ENMF general recommender
    model_conf = {'model': 'ENMF'}
    quick_test(model_conf)


class TestContextRecommender(unittest.TestCase):
Expand Down