Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: add token2id & id2token to Dataset #511

Merged
merged 1 commit into from
Nov 19, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ class Dataset(object):
Specially, if feature is loaded from Arg ``additional_feat_suffix``, its source has type str,
which is the suffix of its local file (also the suffix written in Arg ``additional_feat_suffix``).

field2id_token (dict): Dict mapping feature name (str) to a list, which stores the original token of
this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
field2id_token (dict): Dict mapping feature name (str) to a :class:`np.ndarray`, which stores the original token
of this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
is remapped to 2. Then ``field2id_token['test'] = ['[PAD]', 'token_a', 'token_b']``. (Note that 0 is
always PADDING for token-like features.)

field2token_id (dict): Dict mapping feature name (str) to a dict, which stores the token remap table
of this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
is remapped to 2. Then ``field2token_id['test'] = {'[PAD]': 0, 'token_a': 1, 'token_b': 2}``.
(Note that 0 is always PADDING for token-like features.)

field2seqlen (dict): Dict mapping feature name (str) to its sequence length (int).
For sequence features, their length can be either set in config,
or set to the max sequence length of this feature.
Expand Down Expand Up @@ -116,6 +121,7 @@ def _get_preset(self):
self.field2type = {}
self.field2source = {}
self.field2id_token = {}
self.field2token_id = {}
self.field2seqlen = self.config['seq_len'] or {}
self._preloaded_weight = {}
self.benchmark_filename_list = self.config['benchmark_filename']
Expand Down Expand Up @@ -897,11 +903,13 @@ def _remap(self, remap_list):
tokens, split_point = self._concat_remaped_tokens(remap_list)
new_ids_list, mp = pd.factorize(tokens)
new_ids_list = np.split(new_ids_list + 1, split_point)
mp = ['[PAD]'] + list(mp)
mp = np.array(['[PAD]'] + list(mp))
token_id = {t: i for i, t in enumerate(mp)}

for (feat, field, ftype), new_ids in zip(remap_list, new_ids_list):
if (field not in self.field2id_token):
if field not in self.field2id_token:
self.field2id_token[field] = mp
self.field2token_id[field] = token_id
if ftype == FeatureType.TOKEN:
feat[field] = new_ids
elif ftype == FeatureType.TOKEN_SEQ:
Expand Down Expand Up @@ -1010,6 +1018,46 @@ def copy_field_property(self, dest_field, source_field):
self.field2source[dest_field] = self.field2source[source_field]
self.field2seqlen[dest_field] = self.field2seqlen[source_field]

@dlapi.set()
def token2id(self, field, tokens):
"""Map external tokens to internal ids.

Args:
field (str): Field of external tokens.
tokens (str, list or np.ndarray): External tokens.

Returns:
int or np.ndarray: The internal ids of external tokens.
"""
if isinstance(tokens, str):
if tokens in self.field2token_id[field]:
return self.field2token_id[field][tokens]
else:
raise ValueError('token [{}] is not existed')
elif isinstance(tokens, (list, np.ndarray)):
return np.array([self.token2id(field, token) for token in tokens])
else:
raise TypeError('The type of tokens [{}] is not supported')

@dlapi.set()
def id2token(self, field, ids):
"""Map internal ids to external tokens.

Args:
field (str): Field of internal ids.
ids (int, list, np.ndarray or torch.Tensor): Internal ids.

Returns:
str or np.ndarray: The external tokens of internal ids.
"""
try:
return self.field2id_token[field][ids]
except IndexError:
if isinstance(ids, list):
raise ValueError('[{}] is not a one-dimensional list'.format(ids))
else:
raise ValueError('[{}] is not a valid ids'.format(ids))

@property
@dlapi.set()
def user_num(self):
Expand Down