Skip to content

Commit

Permalink
Merge pull request #666 from hyp1231/master
Browse files Browse the repository at this point in the history
FIX: code format
  • Loading branch information
chenyushuo authored Jan 12, 2021
2 parents 9591235 + b70dcdf commit 42bf06d
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 22 deletions.
7 changes: 0 additions & 7 deletions recbole/data/dataloader/abstract_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,6 @@ def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE,
if self.real_time is None:
self.real_time = True

self.join = self.dataset.join
self.history_item_matrix = self.dataset.history_item_matrix
self.history_user_matrix = self.dataset.history_user_matrix
self.inter_matrix = self.dataset.inter_matrix
self.get_user_feature = self.dataset.get_user_feature
self.get_item_feature = self.dataset.get_item_feature

for dataset_attr in self.dataset._dataloader_apis:
try:
flag = hasattr(self.dataset, dataset_attr)
Expand Down
6 changes: 6 additions & 0 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,7 @@ def _check_field(self, *field_names):
if getattr(self, field_name, None) is None:
raise ValueError(f'{field_name} isn\'t set.')

@dlapi.set()
def join(self, df):
"""Given interaction feature, join user/item feature into it.
Expand Down Expand Up @@ -1429,6 +1430,7 @@ def save(self, filepath):
if df is not None:
df.to_csv(os.path.join(filepath, f'{name}.csv'))

@dlapi.set()
def get_user_feature(self):
"""
Returns:
Expand All @@ -1440,6 +1442,7 @@ def get_user_feature(self):
else:
return self.user_feat

@dlapi.set()
def get_item_feature(self):
"""
Returns:
Expand Down Expand Up @@ -1536,6 +1539,7 @@ def _create_graph(self, tensor_feat, source_field, target_field, form='dgl', val
else:
raise NotImplementedError(f'Graph format [{form}] has not been implemented.')

@dlapi.set()
def inter_matrix(self, form='coo', value_field=None):
"""Get sparse matrix that describe interactions between user_id and item_id.
Expand Down Expand Up @@ -1617,6 +1621,7 @@ def _history_matrix(self, row, value_field=None):

return torch.LongTensor(history_matrix), torch.FloatTensor(history_value), torch.LongTensor(history_len)

@dlapi.set()
def history_item_matrix(self, value_field=None):
"""Get dense matrix describe user's history interaction records.
Expand All @@ -1641,6 +1646,7 @@ def history_item_matrix(self, value_field=None):
"""
return self._history_matrix(row='user', value_field=value_field)

@dlapi.set()
def history_user_matrix(self, value_field=None):
"""Get dense matrix describe item's history interaction records.
Expand Down
5 changes: 2 additions & 3 deletions recbole/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from recbole.config import EvalSetting
from recbole.data.dataloader import *
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType
from recbole.utils import ModelType, ensure_dir


def create_dataset(config):
Expand Down Expand Up @@ -216,8 +216,7 @@ def save_datasets(save_path, name, dataset):
raise ValueError(f'Length of name {name} should equal to length of dataset {dataset}.')
for i, d in enumerate(dataset):
cur_path = os.path.join(save_path, name[i])
if not os.path.isdir(cur_path):
os.makedirs(cur_path)
ensure_dir(cur_path)
d.save(cur_path)


Expand Down
2 changes: 0 additions & 2 deletions recbole/model/abstract_recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def __init__(self, config, dataset):
self.n_items = dataset.num(self.ITEM_ID)

# load parameters info
self.batch_size = config['train_batch_size']
self.device = config['device']


Expand Down Expand Up @@ -145,7 +144,6 @@ def __init__(self, config, dataset):
self.n_relations = dataset.num(self.RELATION_ID)

# load parameters info
self.batch_size = config['train_batch_size']
self.device = config['device']


Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/dgcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, config, dataset):
self.n_layers = config['n_layers']
self.reg_weight = config['reg_weight']
self.cor_weight = config['cor_weight']
n_batch = dataset.dataset.inter_num // self.batch_size + 1
n_batch = dataset.dataset.inter_num // config['train_batch_size'] + 1
self.cor_batch_size = int(max(self.n_users / n_batch, self.n_items / n_batch))
# ensure embedding can be divided into <n_factors> intent
assert self.embedding_size % self.n_factors == 0
Expand Down
4 changes: 2 additions & 2 deletions recbole/model/sequential_recommender/gru4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _init_weights(self, module):
if isinstance(module, nn.Embedding):
xavier_normal_(module.weight)
elif isinstance(module, nn.GRU):
xavier_uniform_(self.gru_layers.weight_hh_l0)
xavier_uniform_(self.gru_layers.weight_ih_l0)
xavier_uniform_(module.weight_hh_l0)
xavier_uniform_(module.weight_ih_l0)

def forward(self, item_seq, item_seq_len):
item_seq_emb = self.item_embedding(item_seq)
Expand Down
4 changes: 2 additions & 2 deletions recbole/model/sequential_recommender/ksr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def _init_weights(self, module):
if isinstance(module, nn.Embedding):
xavier_normal_(module.weight)
elif isinstance(module, nn.GRU):
xavier_uniform_(self.gru_layers.weight_hh_l0)
xavier_uniform_(self.gru_layers.weight_ih_l0)
xavier_uniform_(module.weight_hh_l0)
xavier_uniform_(module.weight_ih_l0)

def _get_kg_embedding(self, head):
"""Difference:
Expand Down
4 changes: 2 additions & 2 deletions recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from logging import getLogger
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
Expand Down Expand Up @@ -64,7 +63,7 @@ class Trainer(AbstractTrainer):
Initializing the Trainer needs two parameters: `config` and `model`. `config` records the parameters information
for controlling training and evaluation, such as `learning_rate`, `epochs`, `eval_step` and so on.
More information can be found in [placeholder]. `model` is the instantiated object of a Model Class.
`model` is the instantiated object of a Model Class.
"""

Expand Down Expand Up @@ -422,6 +421,7 @@ def plot_train_loss(self, show=True, save_path=None):
save_path (str, optional): The data path to save the figure, default: None.
If it's None, it will not be saved.
"""
import matplotlib.pyplot as plt
epochs = list(self.train_loss_dict.keys())
epochs.sort()
values = [float(self.train_loss_dict[epoch]) for epoch in epochs]
Expand Down
5 changes: 2 additions & 3 deletions recbole/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import os

from recbole.utils.utils import get_local_time
from recbole.utils.utils import get_local_time, ensure_dir


def init_logger(config):
Expand All @@ -30,8 +30,7 @@ def init_logger(config):
"""
LOGROOT = './log/'
dir_name = os.path.dirname(LOGROOT)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
ensure_dir(dir_name)

logfilename = '{}-{}.log'.format(config['model'], get_local_time())

Expand Down

0 comments on commit 42bf06d

Please sign in to comment.