From fdc582a0c3d43e5907aba60809d303a54e90c557 Mon Sep 17 00:00:00 2001
From: Jason Liu
Date: Mon, 20 Sep 2021 17:03:39 -0700
Subject: [PATCH] update esm tokenization with save and special token handling

---
 .../models/esm/tokenization_esm.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/transformers/models/esm/tokenization_esm.py b/src/transformers/models/esm/tokenization_esm.py
index d94091580d55b0..f842ea5150dd37 100644
--- a/src/transformers/models/esm/tokenization_esm.py
+++ b/src/transformers/models/esm/tokenization_esm.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for ESM."""
+import os
+from typing import List, Optional
 
 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
 
 logger = logging.get_logger(__name__)
 
 VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
@@ -77,3 +79,21 @@ def token_to_id(self, token: str) -> int:
 
     def id_to_token(self, index: int) -> str:
         return self._id_to_token.get(index, self.unk_token)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """Prepend the CLS token to a single sequence of token ids (pairs are not supported)."""
+        if token_ids_1 is not None:
+            raise NotImplementedError("The ESM tokenizer does not support sequence pairs.")
+        cls_: List[int] = [self.cls_token_id]
+        return cls_ + token_ids_0
+
+    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None):
+        """Write the vocabulary, one token per line, to `save_directory` and return the file path."""
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        # utf-8 explicitly: the default encoding is platform-dependent
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write("\n".join(self.all_tokens))
+        return (vocab_file,)