From b3727b69feb1215216f54f54f6337c53c633395f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 18 Nov 2020 20:49:47 +0800 Subject: [PATCH] FEA: add token2id & id2token to Dataset --- recbole/data/dataset/dataset.py | 56 ++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index 98c821a9d..aadf4eeeb 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -55,11 +55,16 @@ class Dataset(object): Specially, if feature is loaded from Arg ``additional_feat_suffix``, its source has type str, which is the suffix of its local file (also the suffix written in Arg ``additional_feat_suffix``). - field2id_token (dict): Dict mapping feature name (str) to a list, which stores the original token of - this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b`` + field2id_token (dict): Dict mapping feature name (str) to a :class:`np.ndarray`, which stores the original token + of this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b`` is remapped to 2. Then ``field2id_token['test'] = ['[PAD]', 'token_a', 'token_b']``. (Note that 0 is always PADDING for token-like features.) + field2token_id (dict): Dict mapping feature name (str) to a dict, which stores the token remap table + of this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b`` + is remapped to 2. Then ``field2token_id['test'] = {'[PAD]': 0, 'token_a': 1, 'token_b': 2}``. + (Note that 0 is always PADDING for token-like features.) + field2seqlen (dict): Dict mapping feature name (str) to its sequence length (int). For sequence features, their length can be either set in config, or set to the max sequence length of this feature. 
@@ -116,6 +121,7 @@ def _get_preset(self):
         self.field2type = {}
         self.field2source = {}
         self.field2id_token = {}
+        self.field2token_id = {}
         self.field2seqlen = self.config['seq_len'] or {}
         self._preloaded_weight = {}
         self.benchmark_filename_list = self.config['benchmark_filename']
@@ -897,11 +903,13 @@ def _remap(self, remap_list):
         tokens, split_point = self._concat_remaped_tokens(remap_list)
         new_ids_list, mp = pd.factorize(tokens)
         new_ids_list = np.split(new_ids_list + 1, split_point)
-        mp = ['[PAD]'] + list(mp)
+        mp = np.array(['[PAD]'] + list(mp))
+        token_id = {t: i for i, t in enumerate(mp)}
 
         for (feat, field, ftype), new_ids in zip(remap_list, new_ids_list):
-            if (field not in self.field2id_token):
+            if field not in self.field2id_token:
                 self.field2id_token[field] = mp
+                self.field2token_id[field] = token_id
             if ftype == FeatureType.TOKEN:
                 feat[field] = new_ids
             elif ftype == FeatureType.TOKEN_SEQ:
@@ -1010,6 +1018,46 @@ def copy_field_property(self, dest_field, source_field):
         self.field2source[dest_field] = self.field2source[source_field]
         self.field2seqlen[dest_field] = self.field2seqlen[source_field]
 
+    @dlapi.set()
+    def token2id(self, field, tokens):
+        """Map external tokens to internal ids.
+
+        Args:
+            field (str): Field of external tokens.
+            tokens (str, list or np.ndarray): External tokens.
+
+        Returns:
+            int or np.ndarray: The internal ids of external tokens.
+        """
+        if isinstance(tokens, str):
+            if tokens in self.field2token_id[field]:
+                return self.field2token_id[field][tokens]
+            else:
+                raise ValueError('token [{}] is not existed'.format(tokens))
+        elif isinstance(tokens, (list, np.ndarray)):
+            return np.array([self.token2id(field, token) for token in tokens])
+        else:
+            raise TypeError('The type of tokens [{}] is not supported'.format(type(tokens)))
+
+    @dlapi.set()
+    def id2token(self, field, ids):
+        """Map internal ids to external tokens.
+
+        Args:
+            field (str): Field of internal ids.
+            ids (int, list, np.ndarray or torch.Tensor): Internal ids.
+
+        Returns:
+            str or np.ndarray: The external tokens of internal ids.
+        """
+        try:
+            return self.field2id_token[field][ids]
+        except IndexError:
+            if isinstance(ids, list):
+                raise ValueError('[{}] is not a one-dimensional list'.format(ids))
+            else:
+                raise ValueError('[{}] is not a valid ids'.format(ids))
+
     @property
     @dlapi.set()
     def user_num(self):