Skip to content

Commit

Permalink
Merge pull request #10 from RUCAIBox/master
Browse files Browse the repository at this point in the history
update
  • Loading branch information
2017pxy authored Aug 26, 2020
2 parents 5de3955 + b909aa2 commit fadadad
Show file tree
Hide file tree
Showing 64 changed files with 1,361 additions and 740 deletions.
2 changes: 0 additions & 2 deletions config/__init__.py

This file was deleted.

84 changes: 0 additions & 84 deletions data/interaction.py

This file was deleted.

19 changes: 10 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
from config import Config
from data import Dataset, data_preparation
from model.general_recommender.bprmf import BPRMF
from trainer import Trainer
from utils import init_logger
from logging import getLogger
from recbox.config import Config
from recbox.data import Dataset, data_preparation
from recbox.model.general_recommender.bprmf import BPRMF
from recbox.trainer import Trainer
from recbox.utils import init_logger, get_model

config = Config('properties/overall.config')
config.init()
init_logger(config)
logger = getLogger()

dataset = Dataset(config)
logger.info(dataset)

model = BPRMF(config, dataset).to(config['device'])
logger.info(model)

# If you want to customize the evaluation setting,
# please refer to `data_preparation()` in `data/utils.py`.
train_data, test_data, valid_data = data_preparation(config, model, dataset)
train_data, test_data, valid_data = data_preparation(config, dataset)

model = get_model(config['model'])(config, train_data).to(config['device'])
logger.info(model)

trainer = Trainer(config, model)

Expand Down
80 changes: 0 additions & 80 deletions model/sequential_recommender/gru4rec.py

This file was deleted.

5 changes: 5 additions & 0 deletions properties/model/DMF.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[model]

layers = [64,64]
reg_weight = 0.0
min_y_hat = 1e-6
10 changes: 8 additions & 2 deletions properties/model/GRU4Rec.config
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
[model]

MAX_ITEM_LIST_LENGTH=50
TARGET_PREFIX='target_'
LIST_SUFFIX='_list'
POSITION_FIELD='position_id'
ITEM_LIST_LENGTH_FIELD='item_length'

embedding_size=64
num_layers=3
dropout=0.1
num_layers=1
dropout=0
3 changes: 3 additions & 0 deletions properties/model/ItemKNN.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[model]

k=100
2 changes: 1 addition & 1 deletion properties/model/NGCF.config
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

embedding_size=64
layers=[64,64,64]
node_dropout=0.1
node_dropout=0.0
message_dropout=0.1
delay=1e-5
13 changes: 13 additions & 0 deletions properties/model/SASRec.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[model]

MAX_ITEM_LIST_LENGTH=50
TARGET_PREFIX='target_'
LIST_SUFFIX='_list'
POSITION_FIELD='position_id'
ITEM_LIST_LENGTH_FIELD='item_length'

embedding_size=64
n_head=1
d_ff=128
dropout=0
num_blocks=1
2 changes: 2 additions & 0 deletions recbox/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .configurator import Config
from .eval_setting import EvalSetting
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from utils.enum_type import *
from configparser import ConfigParser
from ..utils.enum_type import *


class AbstractConfig(object):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from config.abstract_configurator import AbstractConfig
from .abstract_configurator import AbstractConfig


class CmdConfig(AbstractConfig):
Expand Down
15 changes: 9 additions & 6 deletions config/configurator.py → recbox/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import random
import sys
import numpy as np
from config.running_configurator import RunningConfig
from config.model_configurator import ModelConfig
from config.data_configurator import DataConfig
from config.cmd_configurator import CmdConfig
from .running_configurator import RunningConfig
from .model_configurator import ModelConfig
from .data_configurator import DataConfig
from .cmd_configurator import CmdConfig


class Config(object):
Expand Down Expand Up @@ -61,24 +61,27 @@ def __init__(self, config_file_name, config_dict=None):
self.dataset_args = DataConfig(dataset_arg_file_name, self.cmd_args_dict)

self.device = None
self._init_device()

def init(self):
def _init_device(self):
"""
This function is a global initialization function that fix random seed and gpu device.
"""
init_seed = self.run_args['seed']
use_gpu = self.run_args['use_gpu']
if use_gpu:
gpu_id = self.run_args['gpu_id']
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# Get the device that run on.
self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")

def init(self):
init_seed = self.run_args['seed']
random.seed(init_seed)
np.random.seed(init_seed)
torch.manual_seed(init_seed)
torch.cuda.manual_seed(init_seed)
torch.cuda.manual_seed_all(init_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def _read_cmd_line(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from config.abstract_configurator import AbstractConfig
from .abstract_configurator import AbstractConfig


class DataConfig(AbstractConfig):
Expand Down
7 changes: 5 additions & 2 deletions config/eval_setting.py → recbox/config/eval_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2020/8/13 11:11
# @Time : 2020/8/19 18:56
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn


class EvalSetting(object):
def __init__(self, config):
self.config = config
Expand Down Expand Up @@ -76,14 +77,16 @@ def random_ordering(self):
self.set_ordering('shuffle')

def sort_by(self, field, ascending=None):
if not isinstance(field, list):
field = [field]
if ascending is None:
ascending = [True] * len(field)
if len(ascending) == 1:
ascending = True
self.set_ordering('by', field=field, ascending=ascending)

def temporal_ordering(self):
self.sort_by(field=self.config['TIMESTAMP_FIELD'])
self.sort_by(field=self.config['TIME_FIELD'])

r"""Setting about split method
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from config.abstract_configurator import AbstractConfig
from .abstract_configurator import AbstractConfig


class ModelConfig(AbstractConfig):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from config.abstract_configurator import AbstractConfig
import os
from evaluator import loss_metrics, topk_metrics
from utils import EvaluatorType
from ..evaluator import loss_metrics, topk_metrics
from ..utils import EvaluatorType, get_model
from .abstract_configurator import AbstractConfig


class RunningConfig(AbstractConfig):
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, config_file_name, cmd_args):
eval_type = EvaluatorType.RANKING
self['eval_type'] = eval_type

smaller_metric = ['rmse','mae', 'logloss']
smaller_metric = ['rmse', 'mae', 'logloss']

if 'valid_metric' not in self:
valid_metric = self['metric'][0]
Expand All @@ -54,6 +54,6 @@ def __init__(self, config_file_name, cmd_args):
else:
self['valid_metric_bigger'] = True




model = get_model(self['model'])
self['MODEL_TYPE'] = model.type
self['MODEL_INPUT_TYPE'] = model.input_type
File renamed without changes.
Loading

0 comments on commit fadadad

Please sign in to comment.