Skip to content

Commit

Permalink
Merge pull request #20 from RUCAIBox/0.2.x
Browse files (browse the repository at this point in the history)
0.2.x
  • Loading branch information
2017pxy authored Jan 2, 2021
2 parents 0c34a02 + 1c11a5b commit d59f12d
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 47 deletions.
12 changes: 8 additions & 4 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,15 +567,19 @@ def _normalize(self):
lst = feat[field].values
mx, mn = max(lst), min(lst)
if mx == mn:
raise ValueError('All the same value in [{}] from [{}_feat]'.format(field, feat))
feat[field] = (lst - mn) / (mx - mn)
self.logger.warning('All the same value in [{}] from [{}_feat]'.format(field, feat))
feat[field] = 1.0
else:
feat[field] = (lst - mn) / (mx - mn)
elif ftype == FeatureType.FLOAT_SEQ:
split_point = np.cumsum(feat[field].agg(len))[:-1]
lst = feat[field].agg(np.concatenate)
mx, mn = max(lst), min(lst)
if mx == mn:
raise ValueError('All the same value in [{}] from [{}_feat]'.format(field, feat))
lst = (lst - mn) / (mx - mn)
self.logger.warning('All the same value in [{}] from [{}_feat]'.format(field, feat))
lst = 1.0
else:
lst = (lst - mn) / (mx - mn)
lst = np.split(lst, split_point)
feat[field] = lst

Expand Down
59 changes: 41 additions & 18 deletions recbole/data/dataset/xgboost_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, config, saved_dataset=None):
super().__init__(config, saved_dataset=saved_dataset)

def _judge_token_and_convert(self, feat):
# get columns whose type is token
col_list = []
for col_name in feat:
if col_name == self.uid_field or col_name == self.iid_field:
Expand All @@ -35,31 +36,53 @@ def _judge_token_and_convert(self, feat):
col_list.append(col_name)
elif self.field2type[col_name] == FeatureType.TOKEN_SEQ or self.field2type[col_name] == FeatureType.FLOAT_SEQ:
feat = feat.drop([col_name], axis=1, inplace=False)
feat = pd.get_dummies(feat, sparse = True, columns = col_list)
for col_name in feat.columns.values.tolist():
if col_name not in self.field2type.keys():
self.field2type[col_name] = FeatureType.TOKEN

# get hash map
for col in col_list:
self.hash_map[col] = dict({})
self.hash_count[col] = 0

del_col = []
for col in self.hash_map:
if col in feat.keys():
for value in feat[col]:
#print(value)
if value not in self.hash_map[col]:
self.hash_map[col][value] = self.hash_count[col]
self.hash_count[col] = self.hash_count[col] + 1
if self.hash_count[col] > self.config['token_num_threhold']:
del_col.append(col)
break

for col in del_col:
del self.hash_count[col]
del self.hash_map[col]
col_list.remove(col)
self.convert_col_list.extend(col_list)

# transform the original data
for col in self.hash_map.keys():
if col in feat.keys():
feat[col] = feat[col].map(self.hash_map[col])

return feat

def _convert_token_to_onehot(self):
"""Convert the data of token type to onehot form
def _convert_token_to_hash(self):
"""Convert the data of token type to hash form
"""
self.hash_map = {}
self.hash_count = {}
self.convert_col_list = []
if self.config['convert_token_to_onehot'] == True:
feat_list = []
for feat in (self.inter_feat, self.user_feat, self.item_feat):
feat = self._judge_token_and_convert(feat)
if feat is not None:
feat = self._judge_token_and_convert(feat)
feat_list.append(feat)
self.inter_feat_xgb = feat_list[0]
self.user_feat_xgb = feat_list[1]
self.item_feat_xgb = feat_list[2]
self.inter_feat = self.inter_feat_xgb
self.user_feat = self.user_feat_xgb
self.item_feat = self.item_feat
else:
self.inter_feat_xgb = self.inter_feat
self.user_feat_xgb = self.user_feat
self.item_feat_xgb = self.item_feat
self.inter_feat = feat_list[0]
self.user_feat = feat_list[1]
self.item_feat = feat_list[2]

def _from_scratch(self):
"""Load dataset from scratch.
Expand All @@ -71,7 +94,7 @@ def _from_scratch(self):
self._get_field_from_config()
self._load_data(self.dataset_name, self.dataset_path)
self._data_processing()
self._convert_token_to_onehot()
self._convert_token_to_hash()
self._change_feat_format()

def join(self, df):
Expand Down
2 changes: 1 addition & 1 deletion recbole/properties/dataset/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ not_equal_val: ~
fields_in_same_space: ~
preload_weight: ~
normalize_field: ~
normalize_all: True
normalize_all: ~

# Sequential Model Needed
ITEM_LIST_LENGTH_FIELD: item_length
Expand Down
6 changes: 3 additions & 3 deletions recbole/properties/model/xgboost.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Type of training method
convert_token_to_onehot: True
convert_token_to_onehot: False
token_num_threhold: 10000

# DMatrix

xgb_weight: ~
xgb_base_margin: ~
xgb_missing: ~
Expand All @@ -25,7 +25,7 @@ xgb_params:
eta: 1
seed: 2020
# nthread: -1
xgb_num_boost_round: 200
xgb_num_boost_round: 500
# xgb_evals: ~
xgb_obj: ~
xgb_feval: ~
Expand Down
74 changes: 53 additions & 21 deletions recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def __init__(self, config, model):
self.logger = getLogger()
self.label_field = config['LABEL_FIELD']
self.xgb_model = config['xgb_model']
self.convert_token_to_onehot = self.config['convert_token_to_onehot']

# DMatrix params
self.weight = config['xgb_weight']
Expand Down Expand Up @@ -633,54 +634,85 @@ def __init__(self, config, model):
saved_model_file = '{}-{}.pth'.format(self.config['model'], get_local_time())
self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)

def _interaction_to_DMatrix(self, interaction):
def _interaction_to_DMatrix(self, dataloader):
r"""Convert data format from interaction to DMatrix
Args:
interaction (Interaction): Data in the form of 'Interaction'.
dataloader (XgboostDataLoader): xgboost dataloader.
Returns:
DMatrix: Data in the form of 'DMatrix'.
"""
interaction = dataloader.dataset[:]
interaction_np = interaction.numpy()
cur_data = np.array([])
columns = []
for key, value in interaction_np.items():
value = np.resize(value, (value.shape[0], 1))
if key != self.label_field:
columns.append(key)
if cur_data.shape[0] == 0:
cur_data = value
else:
cur_data = np.hstack((cur_data, value))

return self.xgb.DMatrix(data=cur_data,
label=interaction_np[self.label_field],
weight=self.weight,
base_margin=self.base_margin,
missing=self.missing,
silent=self.silent,
feature_names=self.feature_names,
feature_types=self.feature_types,
nthread=self.nthread)
if self.convert_token_to_onehot == True:
from scipy import sparse
from scipy.sparse import dok_matrix
convert_col_list = dataloader.dataset.convert_col_list
hash_count = dataloader.dataset.hash_count

new_col = cur_data.shape[1] - len(convert_col_list)
for key, values in hash_count.items():
new_col = new_col + values
onehot_data = dok_matrix((cur_data.shape[0], new_col))

cur_j = 0
new_j = 0

for key in columns:
if key in convert_col_list:
for i in range(cur_data.shape[0]):
onehot_data[i, int(new_j + cur_data[i, cur_j])] = 1
new_j = new_j + hash_count[key] - 1
else:
for i in range(cur_data.shape[0]):
onehot_data[i, new_j] = cur_data[i, cur_j]
cur_j = cur_j + 1
new_j = new_j + 1

cur_data = sparse.csc_matrix(onehot_data)

return self.xgb.DMatrix(data = cur_data,
label = interaction_np[self.label_field],
weight = self.weight,
base_margin = self.base_margin,
missing = self.missing,
silent = self.silent,
feature_names = self.feature_names,
feature_types = self.feature_types,
nthread = self.nthread)

def _train_at_once(self, train_data, valid_data):
r"""
Args:
train_data (XgboostDataLoader): XgboostDataLoader, which is the same with GeneralDataLoader.
valid_data (XgboostDataLoader): XgboostDataLoader, which is the same with GeneralDataLoader.
"""
self.dtrain = self._interaction_to_DMatrix(train_data.dataset[:])
self.dvalid = self._interaction_to_DMatrix(valid_data.dataset[:])
self.evals = [(self.dtrain, 'train'), (self.dvalid, 'valid')]
self.model = self.xgb.train(self.params, self.dtrain, self.num_boost_round,
self.evals, self.obj, self.feval, self.maximize,
self.early_stopping_rounds, self.evals_result,
self.verbose_eval, self.xgb_model, self.callbacks)
self.dtrain = self._interaction_to_DMatrix(train_data)
self.dvalid = self._interaction_to_DMatrix(valid_data)
self.evals = [(self.dtrain,'train'),(self.dvalid, 'valid')]
self.model = self.xgb.train(self.params, self.dtrain, self.num_boost_round,
self.evals, self.obj, self.feval, self.maximize,
self.early_stopping_rounds, self.evals_result,
self.verbose_eval, self.xgb_model, self.callbacks)

self.model.save_model(self.saved_model_file)
self.xgb_model = self.saved_model_file

def _valid_epoch(self, valid_data):
r"""
Args:
valid_data (XgboostDataLoader): XgboostDataLoader, which is the same with GeneralDataLoader.
"""
Expand Down Expand Up @@ -720,7 +752,7 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None):
self.eval_pred = torch.Tensor()
self.eval_true = torch.Tensor()

self.deval = self._interaction_to_DMatrix(eval_data.dataset[:])
self.deval = self._interaction_to_DMatrix(eval_data)
self.eval_true = torch.Tensor(self.deval.get_label())
self.eval_pred = torch.Tensor(self.model.predict(self.deval))

Expand Down

0 comments on commit d59f12d

Please sign in to comment.