Skip to content

Commit

Permalink
Merge pull request #194 from chenyushuo/master
Browse files Browse the repository at this point in the history
FEA: Refactor in Dataset
  • Loading branch information
hyp1231 authored Aug 26, 2020
2 parents 1cd2df3 + 2775087 commit b909aa2
Showing 1 changed file with 82 additions and 98 deletions.
180 changes: 82 additions & 98 deletions recbox/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2020/8/24, 2020/8/5, 2020/8/21
# @Time : 2020/8/24, 2020/8/5, 2020/8/26
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : houyupeng@ruc.edu.cn, panxy@ruc.edu.cn, chenyushuo@ruc.edu.cn

Expand Down Expand Up @@ -36,14 +36,7 @@ def _from_scratch(self, config):
self.field2type = {}
self.field2source = {}
self.field2id_token = {}
if config['seq_len'] is not None:
self.field2seqlen = config['seq_len']
else:
self.field2seqlen = {}

self.inter_feat = None
self.user_feat = None
self.item_feat = None
self.field2seqlen = config['seq_len'] or {}

self.model_type = self.config['MODEL_TYPE']
self.uid_field = self.config['USER_ID_FIELD']
Expand All @@ -52,30 +45,16 @@ def _from_scratch(self, config):
self.time_field = self.config['TIME_FIELD']

self.inter_feat, self.user_feat, self.item_feat = self._load_data(self.dataset_name, self.dataset_path)
self.feat_list = [feat for feat in [self.inter_feat, self.user_feat, self.item_feat] if feat is not None]

self.filter_by_inter_num(max_user_inter_num=config['max_user_inter_num'],
min_user_inter_num=config['min_user_inter_num'],
max_item_inter_num=config['max_item_inter_num'],
min_item_inter_num=config['min_item_inter_num'])

self.filter_by_field_value(lowest_val=config['lowest_val'], highest_val=config['highest_val'],
equal_val=config['equal_val'], not_equal_val=config['not_equal_val'],
drop=config['drop_filter_field'])

self._set_label_by_threshold(self.config['threshold'])

self._filter_by_inter_num()
self._filter_by_field_value()
self._reset_index()
self._remap_ID_all()

if self.config['fill_nan']:
self._fill_nan()

if self.config['normalize_field'] is not None and self.config['normalize_all'] is not None:
raise ValueError('normalize_field and normalize_all can\'t be set at the same time')
if self.config['normalize_field']:
self._normalize(self.config['normalize_field'])
elif self.config['normalize_all']:
self._normalize([_ for _ in self.field2type if ((self.field2type[_] == FeatureType.FLOAT) \
or (self.field2type[_] == FeatureType.FLOAT_SEQ))])
self._fill_nan()
self._set_label_by_threshold()
self._normalize()

def _restore_saved_dataset(self, saved_dataset):
if (saved_dataset is None) or (not os.path.isdir(saved_dataset)):
Expand All @@ -93,6 +72,9 @@ def _restore_saved_dataset(self, saved_dataset):
if os.path.isfile(cur_file_name):
df = pd.read_csv(cur_file_name)
setattr(self, '{}_feat'.format(name), df)
else:
setattr(self, '{}_feat'.format(name), None)
self.feat_list = [feat for feat in [self.inter_feat, self.user_feat, self.item_feat] if feat is not None]

self.model_type = self.config['MODEL_TYPE']
self.uid_field = self.config['USER_ID_FIELD']
Expand Down Expand Up @@ -214,34 +196,41 @@ def _float_seq(df, field): df[field] = [list(map(float, _.split(seq_separator)))
return df

def _fill_nan(self):
if not self.config['fill_nan']:
return

most_freq = SimpleImputer(missing_values=np.nan, strategy='most_frequent', copy=False)
aveg = SimpleImputer(missing_values=np.nan, strategy='mean', copy=False)

for feat in [self.inter_feat, self.user_feat, self.item_feat]:
if feat is None:
continue
for field in self.field2type:
if field not in feat:
continue
for feat in self.feat_list:
for field in feat:
ftype = self.field2type[field]
if ftype == FeatureType.TOKEN:
feat.loc[:, field] = most_freq.fit_transform(feat.loc[:, field].values.reshape(-1, 1))
feat[field] = most_freq.fit_transform(feat[field].values.reshape(-1, 1))
elif ftype == FeatureType.FLOAT:
feat.loc[:, field] = aveg.fit_transform(feat.loc[:, field].values.reshape(-1, 1))
feat[field] = aveg.fit_transform(feat[field].values.reshape(-1, 1))
elif ftype.endswith('seq'):
self.logger.warning('feature [{}] (type: {}) probably has nan, while has not been filled.'
.format(field, ftype))

def _normalize(self, fields):
for field in fields:
ftype = self.field2type[field]
if field not in self.field2type:
raise ValueError('Field [{}] doesn\'t exist'.format(field))
elif ftype != FeatureType.FLOAT and ftype != FeatureType.FLOAT_SEQ:
self.logger.warning('{} is not a FLOAT/FLOAT_SEQ feat, which will not be normalized.'.format(field))
for feat in [self.inter_feat, self.user_feat, self.item_feat]:
if feat is None:
continue
def _normalize(self):
if self.config['normalize_field'] is not None and self.config['normalize_all'] is not None:
raise ValueError('normalize_field and normalize_all can\'t be set at the same time')

if self.config['normalize_field']:
fields = self.config['normalize_field']
for field in fields:
ftype = self.field2type[field]
if field not in self.field2type:
raise ValueError('Field [{}] doesn\'t exist'.format(field))
elif ftype != FeatureType.FLOAT and ftype != FeatureType.FLOAT_SEQ:
self.logger.warning('{} is not a FLOAT/FLOAT_SEQ feat, which will not be normalized.'.format(field))
elif self.config['normalize_all']:
fields = self.fields([FeatureType.FLOAT, FeatureType.FLOAT_SEQ])
else:
return

for feat in self.feat_list:
for field in feat:
if field not in fields:
continue
Expand All @@ -262,55 +251,56 @@ def _normalize(self, fields):
lst = np.split(lst, split_point)
feat[field] = lst

def filter_by_inter_num(self, max_user_inter_num=None, min_user_inter_num=None,
max_item_inter_num=None, min_item_inter_num=None):
ban_users = self._get_illegal_ids_by_inter_num(source='user', max_num=max_user_inter_num,
min_num=min_user_inter_num)
ban_items = self._get_illegal_ids_by_inter_num(source='item', max_num=max_item_inter_num,
min_num=min_item_inter_num)
def _filter_by_inter_num(self):
ban_users = self._get_illegal_ids_by_inter_num(field=self.uid_field,
max_num=self.config['max_user_inter_num'],
min_num=self.config['min_user_inter_num'])
ban_items = self._get_illegal_ids_by_inter_num(field=self.iid_field,
max_num=self.config['max_item_inter_num'],
min_num=self.config['min_item_inter_num'])

if len(ban_users) == 0 and len(ban_items) == 0:
return

if self.user_feat is not None:
selected_user = ~self.user_feat[self.uid_field].isin(ban_users)
self.user_feat = self.user_feat[selected_user].reset_index(drop=True)
dropped_user = self.user_feat[self.uid_field].isin(ban_users)
self.user_feat.drop(self.user_feat.index[dropped_user], inplace=True)

if self.item_feat is not None:
selected_item = ~self.item_feat[self.iid_field].isin(ban_users)
self.item_feat = self.item_feat[selected_item].reset_index(drop=True)
dropped_item = self.item_feat[self.iid_field].isin(ban_items)
self.item_feat.drop(self.item_feat.index[dropped_item], inplace=True)

selected_inter = pd.Series(True, index=self.inter_feat.index)
dropped_inter = pd.Series(False, index=self.inter_feat.index)
if self.uid_field:
selected_inter &= ~self.inter_feat[self.uid_field].isin(ban_users)
dropped_inter |= self.inter_feat[self.uid_field].isin(ban_users)
if self.iid_field:
selected_inter &= ~self.inter_feat[self.iid_field].isin(ban_items)
self.inter_feat = self.inter_feat[selected_inter].reset_index(drop=True)
dropped_inter |= self.inter_feat[self.iid_field].isin(ban_items)
self.inter_feat.drop(self.inter_feat.index[dropped_inter], inplace=True)

def _get_illegal_ids_by_inter_num(self, source, max_num=None, min_num=None):
if source not in {'user', 'item'}:
raise ValueError('source [{}] should be user or item'.format(source))
def _get_illegal_ids_by_inter_num(self, field, max_num=None, min_num=None):
if field is None:
return set()
if max_num is None and min_num is None:
return set()

max_num = max_num or np.inf
min_num = min_num or -1

field_name = self.uid_field if source == 'user' else self.iid_field
if field_name is None:
return set()

ids = self.inter_feat[field_name].values
ids = self.inter_feat[field].values
inter_num = Counter(ids)
ids = {id_ for id_ in inter_num if inter_num[id_] < min_num or inter_num[id_] > max_num}
return ids

def filter_by_field_value(self, lowest_val=None, highest_val=None,
equal_val=None, not_equal_val=None, drop=False):
self._filter_by_field_value(lowest_val, lambda x, y: x >= y, drop)
self._filter_by_field_value(highest_val, lambda x, y: x <= y, drop)
self._filter_by_field_value(equal_val, lambda x, y: x == y, drop)
self._filter_by_field_value(not_equal_val, lambda x, y: x != y, drop)
def _filter_by_field_value(self):
drop_field = self.config['drop_filter_field']
changed = False
changed |= self._drop_by_value(self.config['lowest_val'], lambda x, y: x < y, drop_field)
changed |= self._drop_by_value(self.config['highest_val'], lambda x, y: x > y, drop_field)
changed |= self._drop_by_value(self.config['equal_val'], lambda x, y: x != y, drop_field)
changed |= self._drop_by_value(self.config['not_equal_val'], lambda x, y: x == y, drop_field)

if not changed:
return

if self.user_feat is not None:
remained_uids = set(self.user_feat[self.uid_field].values)
Expand All @@ -327,41 +317,35 @@ def filter_by_field_value(self, lowest_val=None, highest_val=None,
remained_inter &= self.inter_feat[self.uid_field].isin(remained_uids)
if self.iid_field is not None:
remained_inter &= self.inter_feat[self.iid_field].isin(remained_iids)
self.inter_feat = self.inter_feat[remained_inter]
self.inter_feat.drop(self.inter_feat.index[~remained_inter], inplace=True)

for source in {'user', 'item', 'inter'}:
feat = getattr(self, '{}_feat'.format(source))
if feat is not None:
feat.reset_index(drop=True, inplace=True)
def _reset_index(self):
for feat in self.feat_list:
feat.reset_index(drop=True, inplace=True)

def _filter_by_field_value(self, val, cmp, drop=False):
def _drop_by_value(self, val, cmp, drop_field=False):
if val is None:
return
all_feats = []
for source in ['inter', 'user', 'item']:
cur_feat = getattr(self, '{}_feat'.format(source))
if cur_feat is not None:
all_feats.append([source, cur_feat])
return False
for field in val:
if field not in self.field2type:
raise ValueError('field [{}] not defined in dataset'.format(field))
for source, cur_feat in all_feats:
if field in cur_feat:
new_feat = cur_feat[cmp(cur_feat[field].values, val[field])]
setattr(self, '{}_feat'.format(source), new_feat)
if drop:
for feat in self.feat_list:
if field in feat:
feat.drop(feat.index[cmp(feat[field].values, val[field])], inplace=True)
if drop_field:
self._del_col(field)
return True

def _del_col(self, field):
for source in ['inter', 'user', 'item']:
cur_feat = getattr(self, '{}_feat'.format(source))
if cur_feat is not None and field in cur_feat:
setattr(self, '{}_feat'.format(source), cur_feat.drop(columns=field))
for feat in self.feat_list:
if field in feat:
feat.drop(columns=field, inplace=True)
for dct in [self.field2id_token, self.field2seqlen, self.field2source, self.field2type]:
if field in dct:
del dct[field]

def _set_label_by_threshold(self, threshold):
def _set_label_by_threshold(self):
threshold = self.config['threshold']
if threshold is None:
return

Expand Down

0 comments on commit b909aa2

Please sign in to comment.