Merge pull request #2 from ShanleiMu/master
FEA: Init the model and trainer part.
ShanleiMu authored Jun 27, 2020
2 parents e5e9fab + 3f7224b commit 5e99000
Showing 7 changed files with 458 additions and 0 deletions.
45 changes: 45 additions & 0 deletions model/abstract_recommender.py
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# @Time : 2020/6/25 15:47
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# @File : abstract_recommender.py

import numpy as np
import torch.nn as nn


class AbstractRecommender(nn.Module):
"""
Base class for all models
"""
def forward(self, *inputs):
"""
Forward pass logic
:return: Model output
"""
raise NotImplementedError

def train_model(self, *inputs):
"""
        Calculate the training loss
        :return: Model training loss
"""
raise NotImplementedError

def predict(self, *inputs):
"""
        Predict scores for testing and evaluation
        :return: Model prediction
"""
raise NotImplementedError

def __str__(self):
"""
Model prints with number of trainable parameters
"""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return super().__str__() + '\nTrainable parameters: {}'.format(params)
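
For orientation, a minimal sketch of how a concrete model is expected to fill in this interface. The class, loss, and dimensions below are hypothetical illustrations, not part of this commit:

import torch
import torch.nn as nn

class DotProductRecommender(AbstractRecommender):
    """Hypothetical toy subclass: scores a (user, item) pair by the dot
    product of learned embeddings."""

    def __init__(self, n_users, n_items, embedding_size=8):
        super().__init__()
        self.user_embedding = nn.Embedding(n_users, embedding_size)
        self.item_embedding = nn.Embedding(n_items, embedding_size)

    def forward(self, user, item):
        return (self.user_embedding(user) * self.item_embedding(item)).sum(dim=1)

    def train_model(self, user, pos_item, neg_item):
        # toy pairwise hinge loss; real models define their own objective
        margin = 1.0 - self.forward(user, pos_item) + self.forward(user, neg_item)
        return torch.relu(margin).mean()

    def predict(self, user, item):
        return self.forward(user, item)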
52 changes: 52 additions & 0 deletions model/general_recommender/bprmf.py
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# @Time : 2020/6/25 16:28
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# @File : bprmf.py

"""
Reference:
Steffen Rendle et al., "BPR: Bayesian Personalized Ranking from Implicit Feedback." in UAI 2009.
"""

import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_

from model.abstract_recommender import AbstractRecommender
from model.loss import BPRLoss


class BPRMF(AbstractRecommender):
    """Matrix factorization model trained with the pairwise BPR loss."""

def __init__(self, config, dataset):
super(BPRMF, self).__init__()

self.embedding_size = config['embedding_size']
self.n_users = dataset.n_users
self.n_items = dataset.n_items

self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
self.loss = BPRLoss()

self._init_weights()

def _init_weights(self):
xavier_normal_(self.user_embedding.weight)
xavier_normal_(self.item_embedding.weight)

def forward(self, user, item):
user_e = self.user_embedding(user)
item_e = self.item_embedding(item)
item_score = torch.mul(user_e, item_e).sum(dim=1)
return item_score

def train_model(self, user, pos_item, neg_item):
pos_item_score = self.forward(user, pos_item)
neg_item_score = self.forward(user, neg_item)
        # BPRLoss (model/loss.py) already returns the negated log-sigmoid,
        # so the result can be minimized directly
        loss = self.loss(pos_item_score, neg_item_score)
return loss

def predict(self, user, item):
return self.forward(user, item)
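
A quick usage sketch; the config dict and the SimpleNamespace dataset stub are hypothetical stand-ins for the framework's real objects, which only need to expose embedding_size, n_users, and n_items here:

from types import SimpleNamespace
import torch

config = {'embedding_size': 64}                      # hypothetical config
dataset = SimpleNamespace(n_users=100, n_items=500)  # hypothetical dataset stub
model = BPRMF(config, dataset)

user = torch.randint(0, 100, (32,))
pos_item = torch.randint(0, 500, (32,))
neg_item = torch.randint(0, 500, (32,))
loss = model.train_model(user, pos_item, neg_item)   # scalar BPR loss
loss.backward()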
68 changes: 68 additions & 0 deletions model/general_recommender/neumf.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# @Time : 2020/6/27 15:10
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# @File : neumf.py

"""
Reference:
Xiangnan He et al., "Neural Collaborative Filtering." in WWW 2017.
"""

import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_

from model.abstract_recommender import AbstractRecommender
from model.layers import MLPLayers


class NeuMF(AbstractRecommender):
    """Neural collaborative filtering: a GMF branch and an MLP branch
    whose outputs are concatenated and fed to a final prediction layer."""

def __init__(self, config, dataset):
super(NeuMF, self).__init__()

self.embedding_size = config['embedding_size']
self.layers = config['layers']
self.dropout = config['dropout']
self.n_users = dataset.n_users
self.n_items = dataset.n_items

self.user_mf_embedding = nn.Embedding(self.n_users, self.embedding_size)
self.item_mf_embedding = nn.Embedding(self.n_items, self.embedding_size)
self.user_mlp_embedding = nn.Embedding(self.n_users, self.layers[0] // 2)
self.item_mlp_embedding = nn.Embedding(self.n_items, self.layers[0] - self.layers[0] // 2)
self.mlp_layers = MLPLayers(self.layers, self.dropout)
self.predict_layer = nn.Linear(self.embedding_size + self.layers[-1], 1)
self.loss = nn.BCEWithLogitsLoss()

self._init_weights()

def _init_weights(self):
xavier_normal_(self.user_mf_embedding.weight)
xavier_normal_(self.item_mf_embedding.weight)
xavier_normal_(self.user_mlp_embedding.weight)
xavier_normal_(self.item_mlp_embedding.weight)
xavier_normal_(self.predict_layer.weight)
for m in self.modules():
if isinstance(m, nn.Linear) and m.bias is not None:
m.bias.data.zero_()

def forward(self, user, item):
user_mf_e = self.user_mf_embedding(user)
item_mf_e = self.item_mf_embedding(item)
user_mlp_e = self.user_mlp_embedding(user)
item_mlp_e = self.item_mlp_embedding(item)

mf_output = torch.mul(user_mf_e, item_mf_e)
mlp_output = self.mlp_layers(torch.cat((user_mlp_e, item_mlp_e), -1))

        output = self.predict_layer(torch.cat((mf_output, mlp_output), -1))
        # squeeze the trailing dimension so the output shape (N,) matches
        # the label shape expected by BCEWithLogitsLoss
        return output.squeeze(-1)

def train_model(self, user, item, label):
output = self.forward(user, item)
return self.loss(output, label)

def predict(self, user, item):
return self.forward(user, item)
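
A matching usage sketch for NeuMF; as above, the config values and dataset stub are hypothetical:

from types import SimpleNamespace
import torch

config = {'embedding_size': 8, 'layers': [16, 8], 'dropout': 0.1}  # hypothetical values
dataset = SimpleNamespace(n_users=100, n_items=500)                # hypothetical dataset stub
model = NeuMF(config, dataset)

user = torch.randint(0, 100, (32,))
item = torch.randint(0, 500, (32,))
label = torch.randint(0, 2, (32,)).float()   # implicit-feedback labels in {0, 1}
loss = model.train_model(user, item, label)  # binary cross-entropy on logits
loss.backward()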
80 changes: 80 additions & 0 deletions model/layers.py
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# @Time : 2020/6/27 16:40
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# @File : layers.py

"""
Common Layers in recommender system
"""

import warnings

import torch.nn as nn
from torch.nn.init import xavier_normal_


class MLPLayers(nn.Module):
""" MLPLayers
Args:
- layers(list): a list contains the size of each layer in mlp layers
- dropout(float): probability of an element to be zeroed. Default: 0
- activation(str): activation function after each layer in mlp layers. Default: 'relu'
candidates: 'sigmoid', 'tanh', 'relu', 'leekyrelu', 'none'
Shape:
- Input: (N, *, H_{in}) where * means any number of additional dimensions
H_{in} must equal to the first value in `layers`
- Output: (N, *, H_{out}) where H_{out} equals to the last value in `layers`
    Examples::
        >>> m = MLPLayers([64, 32, 16], 0.2, 'relu')
        >>> input = torch.randn(128, 64)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 16])
"""

    def __init__(self, layers, dropout=0, activation='relu'):
super(MLPLayers, self).__init__()
self.layers = layers
self.dropout = dropout
self.activation = activation

mlp_modules = []
for idx, (input_size, output_size) in enumerate(zip(self.layers[:-1], self.layers[1:])):
mlp_modules.append(nn.Dropout(p=self.dropout))
mlp_modules.append(nn.Linear(input_size, output_size))

if self.activation.lower() == 'sigmoid':
mlp_modules.append(nn.Sigmoid())
elif self.activation.lower() == 'tanh':
mlp_modules.append(nn.Tanh())
elif self.activation.lower() == 'relu':
mlp_modules.append(nn.ReLU())
            elif self.activation.lower() == 'leakyrelu':
mlp_modules.append(nn.LeakyReLU())
elif self.activation.lower() == 'none':
pass
else:
                warnings.warn('Received unrecognized activation function; no activation is applied.',
                              UserWarning)

self.mlp_layers = nn.Sequential(*mlp_modules)

self._init_weights()

def _init_weights(self):
for m in self.mlp_layers:
if isinstance(m, nn.Linear):
xavier_normal_(m.weight)
for m in self.modules():
if isinstance(m, nn.Linear) and m.bias is not None:
m.bias.data.zero_()

def forward(self, input_feature):
return self.mlp_layers(input_feature)
42 changes: 42 additions & 0 deletions model/loss.py
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# @Time : 2020/6/26 16:41
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# @File : loss.py

"""
Common Loss in recommender system
"""


import torch
import torch.nn as nn


class BPRLoss(nn.Module):
    """ BPRLoss, based on Bayesian Personalized Ranking
    Args:
        - gamma(float): small epsilon added inside the log for numerical stability. Default: 1e-10
    Shape:
        - Pos_score: (N)
        - Neg_score: (N), same shape as Pos_score
        - Output: scalar.
    Examples::
        >>> loss = BPRLoss()
        >>> pos_score = torch.randn(3, requires_grad=True)
        >>> neg_score = torch.randn(3, requires_grad=True)
        >>> output = loss(pos_score, neg_score)
        >>> output.backward()
"""
def __init__(self, gamma=1e-10):
super(BPRLoss, self).__init__()
self.gamma = gamma

    def forward(self, pos_score, neg_score):
        # negated mean log-sigmoid of the score difference: minimizing this
        # pushes positive items to score higher than negative ones
        loss = -torch.log(self.gamma + torch.sigmoid(pos_score - neg_score)).mean()
        return loss
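
As a sanity check on the formula, -log(sigmoid(pos - neg)) equals softplus(neg - pos), so up to the gamma guard the loss can be cross-checked against torch.nn.functional.softplus. This snippet is illustrative, not part of the commit:

import torch
import torch.nn.functional as fn

loss_fn = BPRLoss()
pos_score, neg_score = torch.randn(3), torch.randn(3)
reference = fn.softplus(neg_score - pos_score).mean()
# gamma (1e-10) only guards against log(0), so the two agree to high precision
assert torch.allclose(loss_fn(pos_score, neg_score), reference, atol=1e-6)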
