Merge pull request #618 from chenyushuo/0.2.x
FEA: Add three new features (this diff introduces a show_progress switch that draws tqdm progress bars for training and evaluation, a callback_fn hook in Trainer.fit, and a training_neg_sample_distribution setting)
2017pxy authored Dec 25, 2020
2 parents b936a51 + de64410 commit b2d8664
Showing 5 changed files with 173 additions and 49 deletions.
2 changes: 2 additions & 0 deletions recbole/properties/overall.yaml
@@ -6,13 +6,15 @@ state: INFO
reproducibility: True
data_path: 'dataset/'
checkpoint_dir: 'saved'
+show_progress: True

# training settings
epochs: 300
train_batch_size: 2048
learner: adam
learning_rate: 0.001
training_neg_sample_num: 1
+training_neg_sample_distribution: uniform
eval_step: 1
stopping_step: 10
clip_grad_norm: ~
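
Both new settings are ordinary config keys, so they can be overridden per run instead of editing overall.yaml. A minimal sketch using RecBole's Config entry point; the 'BPR' model and 'ml-100k' dataset are illustrative placeholders, not part of this commit:

from recbole.config import Config

# Override the two defaults added above for a single run
# ('BPR' / 'ml-100k' are placeholder choices).
config = Config(model='BPR', dataset='ml-100k',
                config_dict={'show_progress': False,
                             'training_neg_sample_distribution': 'uniform'})
print(config['show_progress'])  # -> False
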
5 changes: 3 additions & 2 deletions recbole/quick_start/quick_start.py
@@ -49,10 +49,11 @@ def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=Non
trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

# model training
-    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, saved=saved)
+    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, saved=saved,
+                                                      show_progress=config['show_progress'])

# model evaluation
-    test_result = trainer.evaluate(test_data, load_best_model=saved)
+    test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress'])

logger.info('best valid result: {}'.format(best_valid_result))
logger.info('test result: {}'.format(test_result))
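
Since run_recbole now threads show_progress into both trainer.fit and trainer.evaluate, the progress bars can be toggled from a single config key. A hedged usage sketch ('BPR' and 'ml-100k' are placeholders):

from recbole.quick_start import run_recbole

# One flag controls the tqdm bars for training and evaluation alike.
run_recbole(model='BPR', dataset='ml-100k',
            config_dict={'show_progress': True})
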
128 changes: 82 additions & 46 deletions recbole/trainer/trainer.py
@@ -13,7 +13,7 @@
"""

import os
-import itertools
+from tqdm import tqdm
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
@@ -119,14 +119,15 @@ def _build_optimizer(self):
optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
return optimizer

-    def _train_epoch(self, train_data, epoch_idx, loss_func=None):
+    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
r"""Train the model in an epoch
Args:
train_data (DataLoader): The train data.
epoch_idx (int): The current epoch id.
loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
:attr:`self.model.calculate_loss`. Defaults to ``None``.
+            show_progress (bool): Show progress of epoch training. Defaults to ``False``.
Returns:
float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
@@ -136,7 +137,16 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None):
self.model.train()
loss_func = loss_func or self.model.calculate_loss
total_loss = None
-        for batch_idx, interaction in enumerate(train_data):
+        iter_data = (
+            tqdm(
+                enumerate(train_data),
+                total=len(train_data),
+                desc=f"Train {epoch_idx:>5}",
+            )
+            if show_progress
+            else enumerate(train_data)
+        )
+        for batch_idx, interaction in iter_data:
interaction = interaction.to(self.device)
self.optimizer.zero_grad()
losses = loss_func(interaction)
@@ -154,17 +164,18 @@
self.optimizer.step()
return total_loss
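
The iter_data construction above is a small reusable pattern: wrap the iterator in tqdm only when a bar was requested, and pass total= explicitly because enumerate() has no __len__ of its own. A standalone sketch of the same idea, outside any RecBole class:

from tqdm import tqdm

def maybe_progress(data, show_progress=False, desc=''):
    # tqdm cannot infer the length of an enumerate() object, so total=
    # comes from the underlying loader; without show_progress the plain
    # iterator is returned and no bar is drawn.
    it = enumerate(data)
    return tqdm(it, total=len(data), desc=desc) if show_progress else it

for batch_idx, batch in maybe_progress([10, 20, 30], show_progress=True):
    pass  # one bar update per batch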

-    def _valid_epoch(self, valid_data):
+    def _valid_epoch(self, valid_data, show_progress=False):
r"""Valid the model with valid data
Args:
-            valid_data (DataLoader): the valid data
+            valid_data (DataLoader): the valid data.
+            show_progress (bool): Show progress of epoch evaluation. Defaults to ``False``.
Returns:
float: valid score
dict: valid result
"""
-        valid_result = self.evaluate(valid_data, load_best_model=False)
+        valid_result = self.evaluate(valid_data, load_best_model=False, show_progress=show_progress)
valid_score = calculate_valid_score(valid_result, self.valid_metric)
return valid_score, valid_result
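
Because _valid_epoch simply forwards to evaluate(), the flag can also be used when calling the evaluator directly. A hedged sketch, assuming a built trainer and valid_data loader:

# Mirrors what _valid_epoch does internally.
valid_result = trainer.evaluate(valid_data, load_best_model=False, show_progress=True)
valid_score = calculate_valid_score(valid_result, trainer.valid_metric)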

@@ -221,7 +232,7 @@ def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
train_loss_output += 'train loss: %.4f' % losses
return train_loss_output + ']'

-    def fit(self, train_data, valid_data=None, verbose=True, saved=True):
+    def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
r"""Train the model based on the train data and the valid data.
Args:
@@ -230,6 +241,9 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
                If it is ``None``, early stopping is disabled.
verbose (bool, optional): whether to write training and evaluation information to logger, default: True
saved (bool, optional): whether to save the model parameters, default: True
+            show_progress (bool): Show progress of epoch training and evaluation. Defaults to ``False``.
+            callback_fn (callable): Optional callback executed at the end of each validation step,
+                called with ``(epoch_idx, valid_score)`` as arguments.
Returns:
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
@@ -240,7 +254,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
-            train_loss = self._train_epoch(train_data, epoch_idx)
+            train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
training_end_time = time()
train_loss_output = \
@@ -258,7 +272,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
continue
if (epoch_idx + 1) % self.eval_step == 0:
valid_start_time = time()
-                valid_score, valid_result = self._valid_epoch(valid_data)
+                valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
valid_score, self.best_valid_score, self.cur_step,
max_step=self.stopping_step, bigger=self.valid_metric_bigger)
@@ -277,6 +291,9 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
self.logger.info(update_output)
self.best_valid_result = valid_result

+                if callback_fn:
+                    callback_fn(epoch_idx, valid_score)
+
if stop_flag:
stop_output = 'Finished training, best eval result in epoch %d' % \
(epoch_idx - self.cur_step * self.eval_step)
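
The callback_fn hook fires once per validation step, right after the early-stopping bookkeeping above. A minimal sketch of a logging callback, assuming train_data and valid_data loaders already exist:

def log_valid(epoch_idx, valid_score):
    # Matches the (epoch_idx, valid_score) contract documented in fit().
    print(f'epoch {epoch_idx}: valid score = {valid_score:.4f}')

trainer.fit(train_data, valid_data, saved=True,
            show_progress=True, callback_fn=log_valid)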
@@ -312,7 +329,7 @@ def _full_sort_batch_eval(self, batched_data):
return interaction, scores

@torch.no_grad()
-    def evaluate(self, eval_data, load_best_model=True, model_file=None):
+    def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progress=False):
r"""Evaluate the model based on the eval data.
Args:
@@ -321,6 +338,7 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None):
                It should be set to ``True`` if users want to test the model after training.
model_file (str, optional): the saved model file, default: None. If users want to test the previously
trained model file, they can set this parameter.
+            show_progress (bool): Show progress of epoch evaluation. Defaults to ``False``.
Returns:
            dict: eval result, key is the eval metric and value is the corresponding metric value
@@ -346,7 +364,16 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None):
self.tot_item_num = eval_data.dataset.item_num

batch_matrix_list = []
-        for batch_idx, batched_data in enumerate(eval_data):
+        iter_data = (
+            tqdm(
+                enumerate(eval_data),
+                total=len(eval_data),
+                desc="Evaluate   ",
+            )
+            if show_progress
+            else enumerate(eval_data)
+        )
+        for batch_idx, batched_data in iter_data:
if eval_data.dl_type == DataLoaderType.FULL:
interaction, scores = self._full_sort_batch_eval(batched_data)
else:
@@ -412,7 +439,7 @@ def __init__(self, config, model):
self.train_rec_step = config['train_rec_step']
self.train_kg_step = config['train_kg_step']

-    def _train_epoch(self, train_data, epoch_idx, loss_func=None):
+    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
if self.train_rec_step is None or self.train_kg_step is None:
interaction_state = KGDataLoaderState.RSKG
elif epoch_idx % (self.train_rec_step + self.train_kg_step) < self.train_rec_step:
@@ -421,9 +448,11 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None):
interaction_state = KGDataLoaderState.KG
train_data.set_mode(interaction_state)
if interaction_state in [KGDataLoaderState.RSKG, KGDataLoaderState.RS]:
-            return super()._train_epoch(train_data, epoch_idx)
+            return super()._train_epoch(train_data, epoch_idx, show_progress=show_progress)
elif interaction_state in [KGDataLoaderState.KG]:
-            return super()._train_epoch(train_data, epoch_idx, self.model.calculate_kg_loss)
+            return super()._train_epoch(train_data, epoch_idx,
+                                        loss_func=self.model.calculate_kg_loss,
+                                        show_progress=show_progress)
return None
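
The schedule above interleaves recommender (RS) and knowledge-graph (KG) epochs: when both steps are set, each cycle of train_rec_step + train_kg_step epochs spends the first train_rec_step epochs on the RS loss and the rest on the KG loss; if either step is unset, every epoch trains the joint RSKG state. A small sketch of the same arithmetic:

def kg_loader_state(epoch_idx, train_rec_step, train_kg_step):
    # Mirrors the branching in KGTrainer._train_epoch above.
    if train_rec_step is None or train_kg_step is None:
        return 'RSKG'
    if epoch_idx % (train_rec_step + train_kg_step) < train_rec_step:
        return 'RS'
    return 'KG'

# rec_step=2, kg_step=1 -> RS, RS, KG, RS, RS, KG, ...
assert [kg_loader_state(e, 2, 1) for e in range(6)] == ['RS', 'RS', 'KG'] * 2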


@@ -435,14 +464,16 @@ class KGATTrainer(Trainer):
def __init__(self, config, model):
super(KGATTrainer, self).__init__(config, model)

-    def _train_epoch(self, train_data, epoch_idx, loss_func=None):
+    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
# train rs
train_data.set_mode(KGDataLoaderState.RS)
-        rs_total_loss = super()._train_epoch(train_data, epoch_idx)
+        rs_total_loss = super()._train_epoch(train_data, epoch_idx, show_progress=show_progress)

# train kg
train_data.set_mode(KGDataLoaderState.KG)
-        kg_total_loss = super()._train_epoch(train_data, epoch_idx, self.model.calculate_kg_loss)
+        kg_total_loss = super()._train_epoch(train_data, epoch_idx,
+                                             loss_func=self.model.calculate_kg_loss,
+                                             show_progress=show_progress)

# update A
self.model.eval()
@@ -477,12 +508,12 @@ def save_pretrained_model(self, epoch, saved_model_file):
}
torch.save(state, saved_model_file)

-    def pretrain(self, train_data, verbose=True):
+    def pretrain(self, train_data, verbose=True, show_progress=False):

for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
-            train_loss = self._train_epoch(train_data, epoch_idx)
+            train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
training_end_time = time()
train_loss_output = \
@@ -501,11 +532,11 @@ def pretrain(self, train_data, verbose=True):

return self.best_valid_score, self.best_valid_result

-    def fit(self, train_data, valid_data=None, verbose=True, saved=True):
+    def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
if self.model.train_stage == 'pretrain':
-            return self.pretrain(train_data, verbose)
+            return self.pretrain(train_data, verbose, show_progress)
elif self.model.train_stage == 'finetune':
-            return super().fit(train_data, valid_data, verbose, saved)
+            return super().fit(train_data, valid_data, verbose, saved, show_progress, callback_fn)
else:
            raise ValueError("Please make sure that the 'train_stage' is 'pretrain' or 'finetune'")
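
fit() now dispatches on the model's train_stage, so a two-stage run is just two calls. A hedged sketch; how train_stage is configured and which trainer subclass this belongs to are not shown in this diff:

# Stage 1: pretraining (no valid_data, no early stopping).
model.train_stage = 'pretrain'
trainer.fit(pretrain_data, show_progress=True)

# Stage 2: fine-tuning with validation; callback_fn is forwarded as well.
model.train_stage = 'finetune'
trainer.fit(train_data, valid_data, saved=True, show_progress=True)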

@@ -519,19 +550,23 @@ def __init__(self, config, model):
super(MKRTrainer, self).__init__(config, model)
self.kge_interval = config['kge_interval']

-    def _train_epoch(self, train_data, epoch_idx, loss_func=None):
+    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
rs_total_loss, kg_total_loss = 0., 0.

# train rs
self.logger.info('Train RS')
train_data.set_mode(KGDataLoaderState.RS)
-        rs_total_loss = super()._train_epoch(train_data, epoch_idx, self.model.calculate_rs_loss)
+        rs_total_loss = super()._train_epoch(train_data, epoch_idx,
+                                             loss_func=self.model.calculate_rs_loss,
+                                             show_progress=show_progress)

# train kg
if epoch_idx % self.kge_interval == 0:
self.logger.info('Train KG')
train_data.set_mode(KGDataLoaderState.KG)
-            kg_total_loss = super()._train_epoch(train_data, epoch_idx, self.model.calculate_kg_loss)
+            kg_total_loss = super()._train_epoch(train_data, epoch_idx,
+                                                 loss_func=self.model.calculate_kg_loss,
+                                                 show_progress=show_progress)

return rs_total_loss, kg_total_loss
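
Like the other knowledge-aware trainers, MKRTrainer returns a tuple of per-objective losses, which fit() folds into one scalar for the loss dict via the sum(...) line earlier in this diff. Illustrative values only:

# What fit() does with a tuple-valued epoch loss; the numbers are made up.
train_loss = (12.0, 3.5)  # (rs_total_loss, kg_total_loss)
epoch_loss = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
assert epoch_loss == 15.5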

@@ -550,9 +585,10 @@ class xgboostTrainer(AbstractTrainer):
"""xgboostTrainer is designed for XGBOOST.
"""

def __init__(self, config, model):
super(xgboostTrainer, self).__init__(config, model)

self.xgb = __import__('xgboost')

self.logger = getLogger()
@@ -605,22 +641,22 @@ def _interaction_to_DMatrix(self, interaction):
interaction_np = interaction.numpy()
cur_data = np.array([])
for key, value in interaction_np.items():
-            value = np.resize(value,(value.shape[0],1))
+            value = np.resize(value, (value.shape[0], 1))
if key != self.label_field:
if cur_data.shape[0] == 0:
cur_data = value
else:
cur_data = np.hstack((cur_data, value))
-        return self.xgb.DMatrix(data = cur_data,
-                                label = interaction_np[self.label_field],
-                                weight = self.weight,
-                                base_margin = self.base_margin,
-                                missing = self.missing,
-                                silent = self.silent,
-                                feature_names = self.feature_names,
-                                feature_types = self.feature_types,
-                                nthread = self.nthread)
+        return self.xgb.DMatrix(data=cur_data,
+                                label=interaction_np[self.label_field],
+                                weight=self.weight,
+                                base_margin=self.base_margin,
+                                missing=self.missing,
+                                silent=self.silent,
+                                feature_names=self.feature_names,
+                                feature_types=self.feature_types,
+                                nthread=self.nthread)
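
The conversion above stacks every non-label column of the interaction into a 2-D feature matrix and hands it to xgboost. A self-contained sketch with made-up field names:

import numpy as np
import xgboost as xgb

# Toy interaction: two feature columns plus a label column.
interaction_np = {
    'user_id': np.array([1, 2, 3]),
    'item_id': np.array([10, 20, 30]),
    'label': np.array([1.0, 0.0, 1.0]),
}
features = np.hstack([np.resize(v, (v.shape[0], 1))
                      for k, v in interaction_np.items() if k != 'label'])
dmat = xgb.DMatrix(data=features, label=interaction_np['label'])
print(dmat.num_row(), dmat.num_col())  # -> 3 2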

def _train_at_once(self, train_data, valid_data):
r"""
@@ -631,11 +667,11 @@ def _train_at_once(self, train_data, valid_data):
"""
self.dtrain = self._interaction_to_DMatrix(train_data.dataset[:])
self.dvalid = self._interaction_to_DMatrix(valid_data.dataset[:])
-        self.evals = [(self.dtrain,'train'),(self.dvalid, 'valid')]
-        self.model = self.xgb.train(self.params, self.dtrain, self.num_boost_round,
-                                    self.evals, self.obj, self.feval, self.maximize,
-                                    self.early_stopping_rounds, self.evals_result,
-                                    self.verbose_eval, self.xgb_model, self.callbacks)
+        self.evals = [(self.dtrain, 'train'), (self.dvalid, 'valid')]
+        self.model = self.xgb.train(self.params, self.dtrain, self.num_boost_round,
+                                    self.evals, self.obj, self.feval, self.maximize,
+                                    self.early_stopping_rounds, self.evals_result,
+                                    self.verbose_eval, self.xgb_model, self.callbacks)
self.model.save_model(self.saved_model_file)
self.xgb_model = self.saved_model_file

@@ -645,13 +681,13 @@ def _valid_epoch(self, valid_data):
Args:
            valid_data (XgboostDataLoader): XgboostDataLoader, which is the same as GeneralDataLoader.
"""
        valid_result = self.evaluate(valid_data)
valid_score = calculate_valid_score(valid_result, self.valid_metric)
return valid_result, valid_score

def fit(self, train_data, valid_data=None, verbose=True, saved=True):
# load model
-        if self.xgb_model != None:
+        if self.xgb_model is not None:
self.model.load_model(self.xgb_model)

self.best_valid_score = 0.
@@ -666,15 +702,15 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True):
valid_result, valid_score = self._valid_epoch(valid_data)
valid_end_time = time()
valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \
                             (epoch_idx, valid_end_time - valid_start_time, valid_score)
valid_result_output = 'valid result: \n' + dict2str(valid_result)
if verbose:
self.logger.info(valid_score_output)
self.logger.info(valid_result_output)

self.best_valid_score = valid_score
self.best_valid_result = valid_result

return self.best_valid_score, self.best_valid_result

def evaluate(self, eval_data, load_best_model=True, model_file=None):
4 changes: 3 additions & 1 deletion recbole/utils/argument_list.py
@@ -7,11 +7,13 @@
'seed',
'reproducibility',
'state',
-                     'data_path']
+                     'data_path',
+                     'show_progress']

training_arguments = ['epochs', 'train_batch_size',
'learner', 'learning_rate',
'training_neg_sample_num',
+                      'training_neg_sample_distribution',
'eval_step', 'stopping_step',
'checkpoint_dir']
