diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py
index 195efa6e107b..c9c9f87b25ad 100644
--- a/paddlenlp/transformers/generation_utils.py
+++ b/paddlenlp/transformers/generation_utils.py
@@ -422,7 +422,8 @@ def update_model_kwargs_for_generation(outputs,
         # method.
 
         # update cache
-        if isinstance(outputs, tuple):
+        if isinstance(outputs,
+                      tuple) and not isinstance(outputs[1], paddle.Tensor):
             model_kwargs["cache"] = outputs[1]
 
         # update token_type_ids with last value
diff --git a/paddlenlp/transformers/t5/modeling.py b/paddlenlp/transformers/t5/modeling.py
index efeffa66b67e..e054426a0001 100644
--- a/paddlenlp/transformers/t5/modeling.py
+++ b/paddlenlp/transformers/t5/modeling.py
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
 #
@@ -31,6 +30,12 @@
     'T5ForConditionalGeneration',
 ]
 
+T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "t5-small",
+    "t5-base",
+    "t5-large",
+]
+
 
 def finfo(dtype):
     if dtype == paddle.float32:
@@ -107,6 +112,27 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class T5DenseGatedSiluDense(nn.Layer):
+    """
+    Construct a dense-gated_silu-dense module.
+    """
+
+    def __init__(self, d_model, d_ff, dropout_rate):
+        super().__init__()
+        self.wi_0 = nn.Linear(d_model, d_ff, bias_attr=False)
+        self.wi_1 = nn.Linear(d_model, d_ff, bias_attr=False)
+        self.wo = nn.Linear(d_ff, d_model, bias_attr=False)
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, hidden_states):
+        hidden_silu = F.silu(self.wi_0(hidden_states))
+        hidden_linear = self.wi_1(hidden_states)
+        hidden_states = hidden_silu * hidden_linear
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
 class T5LayerFF(nn.Layer):
 
     def __init__(self, feed_forward_proj, d_model, d_ff, layer_norm_epsilon,
@@ -117,6 +143,9 @@ def __init__(self, feed_forward_proj, d_model, d_ff, layer_norm_epsilon,
         elif feed_forward_proj == "gated-gelu":
             self.DenseReluDense = T5DenseGatedGeluDense(d_model, d_ff,
                                                         dropout_rate)
+        elif feed_forward_proj == "gated-silu":
+            self.DenseReluDense = T5DenseGatedSiluDense(d_model, d_ff,
+                                                        dropout_rate)
         else:
             raise ValueError(
                 f"{feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
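The gated-SiLU block added above follows the same GLU-style pattern as the existing gated-GeLU feed-forward: a SiLU-activated projection gates a parallel linear projection before the output projection. A minimal standalone sketch of that computation (class name and tensor shapes are illustrative only, not part of the patch):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class GatedSiluFFN(nn.Layer):
    """Standalone sketch of a dense-gated_silu-dense block."""

    def __init__(self, d_model, d_ff, dropout_rate):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias_attr=False)  # gate branch
        self.wi_1 = nn.Linear(d_model, d_ff, bias_attr=False)  # linear branch
        self.wo = nn.Linear(d_ff, d_model, bias_attr=False)    # output projection
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # silu(x @ W_0) gates the linear branch x @ W_1
        return self.wo(self.dropout(F.silu(self.wi_0(x)) * self.wi_1(x)))

hidden = paddle.randn([2, 8, 64])  # [batch, seq_len, d_model]
ffn = GatedSiluFFN(d_model=64, d_ff=256, dropout_rate=0.1)
print(ffn(hidden).shape)  # [2, 8, 64]
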
@@ -522,6 +551,7 @@ def forward(
             output_attentions=output_attentions,
         )
         hidden_states, present_key_value_state = self_attention_outputs[:2]
+
         attention_outputs = self_attention_outputs[
             2:]  # Keep self-attention outputs and relative position weights
 
@@ -989,7 +1019,7 @@ def forward(self,
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
-            if use_cache is False:
+            if not use_cache:
                 layer_outputs = layer_outputs[:1] + (None, ) + layer_outputs[1:]
 
             hidden_states, present_key_value_state = layer_outputs[:2]
@@ -1040,8 +1070,6 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
                 causal_mask = paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
                                           [batch_size, seq_length, 1
                                            ]) <= seq_ids.unsqueeze(axis=[0, 2])
-                # in case cache are used we need to add a prefix ones mask to the causal mask
-                # causal and attention masks must have same type with pytorch version < 1.3
                 causal_mask = causal_mask.astype(attention_mask.dtype)
 
                 if causal_mask.shape[1] < attention_mask.shape[1]:
@@ -1062,6 +1090,35 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
                     1) * attention_mask.unsqueeze([1, 2])
             else:
                 extended_attention_mask = attention_mask.unsqueeze([1, 2])
+        elif attention_mask.ndim == 4:
+            if self.is_decoder:
+                batch_size, seq_length = input_shape
+                seq_ids = paddle.arange(seq_length)
+                causal_mask = paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
+                                          [batch_size, seq_length, 1
+                                           ]) <= seq_ids.unsqueeze(axis=[0, 2])
+                # in case a cache is used we need to add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.astype(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[-1]:
+                    prefix_seq_len = attention_mask.shape[
+                        -1] - causal_mask.shape[1]
+                    causal_mask = paddle.concat(
+                        [
+                            paddle.ones(
+                                [batch_size, seq_length, prefix_seq_len],
+                                dtype=causal_mask.dtype,
+                            ),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+
+                extended_attention_mask = causal_mask.unsqueeze(
+                    1) * attention_mask
+            else:
+                extended_attention_mask = attention_mask
         else:
             raise ValueError(
                 f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
@@ -1072,10 +1129,12 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
         return extended_attention_mask
 
     def invert_attention_mask(self, encoder_attention_mask):
-        if encoder_attention_mask.ndim == 3:
+        if encoder_attention_mask.ndim == 4:
+            encoder_extended_attention_mask = encoder_attention_mask
+        elif encoder_attention_mask.ndim == 3:
             encoder_extended_attention_mask = encoder_attention_mask.unsqueeze(
                 1)
-        if encoder_attention_mask.ndim == 2:
+        elif encoder_attention_mask.ndim == 2:
             encoder_extended_attention_mask = encoder_attention_mask.unsqueeze(
                 [1, 2])
         encoder_extended_attention_mask = encoder_extended_attention_mask.astype(
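The new `ndim == 4` branch above accepts a mask that is already broadcast to `[batch, 1, query_len, key_len]` and only multiplies the causal constraint into it, instead of expanding a 2-D padding mask. A rough standalone sketch of the shape arithmetic (not the model's code path; shapes are illustrative):

import paddle

batch_size, seq_length = 2, 5
seq_ids = paddle.arange(seq_length)

# Lower-triangular causal mask, shape [batch, seq_length, seq_length]
causal_mask = (paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
                           [batch_size, seq_length, 1])
               <= seq_ids.unsqueeze(axis=[0, 2])).astype("float32")

# A user-supplied 4-D mask, shape [batch, 1, seq_length, seq_length]
attention_mask = paddle.ones([batch_size, 1, seq_length, seq_length])

# Broadcasting keeps both the padding and the causal constraints
extended_attention_mask = causal_mask.unsqueeze(1) * attention_mask
print(extended_attention_mask.shape)  # [2, 1, 5, 5]
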
@@ -1176,6 +1235,13 @@ def __init__(self,
         self.d_model = d_model
         self.initializer_factor = initializer_factor
 
+        if num_decoder_layers is None and num_layers is None:
+            raise ValueError(
+                "You have to specify either num_decoder_layers or num_layers or both."
+            )
+        elif num_decoder_layers is None:
+            num_decoder_layers = num_layers
+
         self.shared = nn.Embedding(vocab_size, d_model)
 
         self.encoder = T5Stack(d_model,
                                num_layers,
@@ -1401,9 +1467,10 @@ def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
     def get_output_embeddings(self):
-        if not self.t5.config["tie_word_embeddings"]:
+        if self.t5.config["tie_word_embeddings"]:
             return self.t5.shared
-        return self.lm_head
+        else:
+            return self.lm_head
 
     def get_encoder(self):
         return self.t5.encoder
@@ -1514,7 +1581,10 @@ def forward(self,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states)
 
-        hidden_states = encoder_output[0]
+        if isinstance(encoder_output, (tuple, list)):
+            hidden_states = encoder_output[0]
+        else:
+            hidden_states = encoder_output
 
         if labels is not None and decoder_input_ids is None:
             # get decoder inputs from shifting lm labels to the right
@@ -1559,6 +1629,9 @@ def forward(self,
             loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
                             labels.flatten())
 
+        if not isinstance(encoder_output, (list, tuple)):
+            encoder_output = (encoder_output, )
+
         output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
         return ((loss, ) + output) if loss is not None else output
diff --git a/paddlenlp/transformers/t5/tokenizer.py b/paddlenlp/transformers/t5/tokenizer.py
index 7f78caa80264..549a9bdccf9c 100644
--- a/paddlenlp/transformers/t5/tokenizer.py
+++ b/paddlenlp/transformers/t5/tokenizer.py
@@ -24,6 +24,12 @@
     'T5Tokenizer',
 ]
 
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "t5-small": 512,
+    "t5-base": 512,
+    "t5-large": 512,
+}
+
 
 class T5Tokenizer(AlbertEnglishTokenizer):
     """
@@ -88,6 +94,8 @@ class T5Tokenizer(AlbertEnglishTokenizer):
         },
     }
 
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
     def __init__(self,
                  sentencepiece_model_file,
                  do_lower_case=False,
@@ -98,6 +106,7 @@ def __init__(self,
                  pad_token="<pad>",
                  extra_ids=100,
                  additional_special_tokens=[],
+                 sp_model_kwargs=None,
                  **kwargs):
 
         # Add extra_ids to the special token list
@@ -123,28 +132,54 @@ def __init__(self,
         self.extra_ids = extra_ids
         self.sentencepiece_model_file = sentencepiece_model_file
 
-        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(sentencepiece_model_file)
 
     def __call__(self,
                  text,
                  text_pair=None,
-                 max_seq_len=None,
+                 max_length=None,
                  stride=0,
                  is_split_into_words=False,
-                 pad_to_max_seq_len=False,
-                 truncation_strategy="longest_first",
+                 padding=None,
+                 truncation="longest_first",
                  return_position_ids=False,
                  return_token_type_ids=False,
                  return_attention_mask=True,
                  return_length=False,
                  return_overflowing_tokens=False,
-                 return_special_tokens_mask=False):
+                 return_special_tokens_mask=False,
+                 **kwargs):
+        if "pad_to_max_seq_len" in kwargs and padding is None:
+            pad_to_max_seq_len = kwargs.pop("pad_to_max_seq_len")
+            padding = "max_length" if pad_to_max_seq_len else False
+        elif padding is None:
+            padding = False
+
+        if "max_seq_len" in kwargs and max_length is None:
+            max_length = kwargs["max_seq_len"]
+
+        if "truncation_strategy" in kwargs and kwargs[
+                "truncation_strategy"] != "longest_first":
+            truncation = kwargs["truncation_strategy"]
+
         return super(T5Tokenizer, self).__call__(
-            text, text_pair, max_seq_len, stride, is_split_into_words,
-            pad_to_max_seq_len, truncation_strategy, return_position_ids,
-            return_token_type_ids, return_attention_mask, return_length,
-            return_overflowing_tokens, return_special_tokens_mask)
+            text=text,
+            text_pair=text_pair,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            padding=padding,
+            truncation=truncation,
+            return_position_ids=return_position_ids,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_length=return_length,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            **kwargs)
 
     @property
     def vocab_size(self):
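The shim above keeps the legacy `max_seq_len` / `pad_to_max_seq_len` / `truncation_strategy` arguments working while the tokenizer moves to the `max_length` / `padding` / `truncation` names. A usage sketch, assuming the `t5-small` vocabulary can be fetched with `from_pretrained`:

from paddlenlp.transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")

# New-style arguments
new_style = tokenizer("translate English to German: hello",
                      max_length=16,
                      padding="max_length")

# Legacy arguments are rewritten onto the new names by the shim
old_style = tokenizer("translate English to German: hello",
                      max_seq_len=16,
                      pad_to_max_seq_len=True)

# Both calls should produce the same 16-token, right-padded sequence
print(new_style["input_ids"])
print(old_style["input_ids"])
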
@@ -254,36 +289,6 @@ def convert_tokens_to_string(self, tokens):
             out_string += self.sp_model.decode_pieces(current_sub_tokens)
         return out_string.strip()
 
-    def decode(self,
-               token_ids,
-               skip_special_tokens=False,
-               clean_up_tokenization_spaces=True):
-        """
-        Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
-        tokens and clean up tokenization spaces.
-
-        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
-
-        Args:
-            token_ids (Union[List[int], Tensor]):
-                List of tokenized input ids.
-            skip_special_tokens (bool, optional):
-                Whether or not to remove special tokens in the decoding. Defaults to `False`.
-            clean_up_tokenization_spaces (bool, optional):
-                Whether or not to clean up the tokenization spaces. Defaults to `True`.
-
-        Returns:
-            str: The decoded sentence.
-        """
-        if hasattr(token_ids, "tolist"):
-            token_ids = token_ids.tolist()
-        text = self.convert_tokens_to_string(
-            self.convert_ids_to_tokens(token_ids,
-                                       skip_special_tokens=skip_special_tokens))
-        if clean_up_tokenization_spaces:
-            text = self.clean_up_tokenization(text)
-        return text
-
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         if token.startswith("<extra_id_"):
+        self.assertEqual(vocab_keys[1], "")
+        self.assertEqual(vocab_keys[-1], "")
+        self.assertEqual(len(vocab_keys), 1_101)
+
+    def test_vocab_size(self):
+        self.assertEqual(self.get_tokenizer().vocab_size, 1_100)
+
+    def test_full_tokenizer(self):
+        tokenizer = T5Tokenizer(SAMPLE_VOCAB)
+
+        tokens = tokenizer.tokenize("This is a test")
+        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
+
+        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
+                             [285, 46, 10, 170, 382])
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(
+            tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "9",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "é",
+                ".",
+            ],
+        )
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(ids, [
+            8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72,
+            80, 6, 0, 4
+        ])
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(
+            back_tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "<unk>",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "<unk>",
+                ".",
+            ],
+        )
+
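The 1_100-token vocabulary asserted above already includes the 100 `<extra_id_n>` sentinels that the constructor appends to `additional_special_tokens`. A plain-Python sketch of that bookkeeping (an approximation of the behaviour, not the tokenizer's exact code):

extra_ids = 100
additional_special_tokens = []

# Mirror of how sentinel tokens are appended when none were passed in
if extra_ids > 0 and not any("extra_id" in t for t in additional_special_tokens):
    additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]

assert len(additional_special_tokens) == 100
assert additional_special_tokens[0] == "<extra_id_0>"
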
+    def t5_base_tokenizer(self):
+        return T5Tokenizer.from_pretrained("t5-base")
+
+    def get_tokenizer(self, **kwargs) -> T5Tokenizer:
+        return self.tokenizer_class.from_pretrained(self.tmpdirname,
+                                                    pad_token=None,
+                                                    **kwargs)
+
+    def test_eos_treatment(self):
+        tokenizer = self.t5_base_tokenizer()
+        batch_with_eos_added = tokenizer(
+            ["hi</s>", "I went to the gym</s>", "</s>"])
+        batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
+        self.assertListEqual(batch_with_eos_added["input_ids"],
+                             batch_without_eos_added["input_ids"])
+
+    def test_prepare_batch(self):
+        tokenizer = self.t5_base_tokenizer()
+        src_text = [
+            "A long paragraph for summarization.",
+            "Another paragraph for summarization."
+        ]
+        expected_src_tokens = [
+            71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id
+        ]
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
+        self.assertIsInstance(batch, BatchEncoding)
+
+        result = list(batch["input_ids"].tolist()[0])
+
+        self.assertListEqual(expected_src_tokens, result)
+
+        self.assertEqual([2, 9], batch["input_ids"].shape)
+        self.assertEqual([2, 9], batch.attention_mask.shape)
+
+    def test_empty_target_text(self):
+        tokenizer = self.t5_base_tokenizer()
+        src_text = [
+            "A long paragraph for summarization.",
+            "Another paragraph for summarization."
+        ]
+        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
+        # check if input_ids are returned and no decoder_input_ids
+        self.assertIn("input_ids", batch)
+        self.assertIn("attention_mask", batch)
+        self.assertNotIn("decoder_input_ids", batch)
+        self.assertNotIn("decoder_attention_mask", batch)
+
+    def test_max_length(self):
+        tokenizer = self.t5_base_tokenizer()
+        tgt_text = [
+            "Summary of the text.",
+            "Another summary.",
+        ]
+        targets = tokenizer(text=tgt_text,
+                            max_length=32,
+                            padding="max_length",
+                            truncation=True,
+                            return_tensors=FRAMEWORK)
+        self.assertEqual(32, targets["input_ids"].shape[1])
+
+    def test_outputs_not_longer_than_maxlen(self):
+        tokenizer = self.t5_base_tokenizer()
+
+        batch = tokenizer(["I am a small frog" * 1000, "I am a small frog"],
+                          padding=True,
+                          truncation=True,
+                          return_tensors=FRAMEWORK)
+        self.assertIsInstance(batch, BatchEncoding)
+        # Since T5 does NOT have a max input length,
+        # this test should be changed to the following in Transformers v5:
+        # self.assertEqual(batch["input_ids"].shape, (2, 8001))
+        self.assertEqual(batch["input_ids"].shape, [2, 512])
+
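The `[2, 512]` shape asserted above comes from the `max_model_input_sizes` entry registered in tokenizer.py, which ties each released checkpoint to a 512-token cap. A hypothetical sketch of how such a cap combines with an explicit `max_length` (helper name and logic are assumptions, not PaddleNLP API):

max_model_input_sizes = {"t5-small": 512, "t5-base": 512, "t5-large": 512}

def effective_max_length(checkpoint, requested=None):
    # Fall back to the per-checkpoint cap when no explicit max_length is given
    cap = max_model_input_sizes.get(checkpoint)
    return cap if requested is None else min(requested, cap)

assert effective_max_length("t5-base") == 512
assert effective_max_length("t5-base", requested=32) == 32
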
"] + expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1] + expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1] + + batch = tokenizer(src_text, text_target=tgt_text) + + self.assertEqual(expected_src_tokens, batch["input_ids"][0]) + # self.assertEqual(expected_tgt_tokens, batch["labels"][0]) + + def test_token_type_ids(self): + src_text_1 = ["A first paragraph for summarization."] + src_text_2 = ["A second paragraph for summarization."] + + tokenizer = self.t5_base_tokenizer() + + slow_token_type_ids = tokenizer( + src_text_1, + src_text_2, + add_special_tokens=True, + return_token_type_ids=True)["token_type_ids"] + + self.assertEqual(len(slow_token_type_ids[0]), 18) + + def test_special_tokens_initialization_with_non_empty_additional_special_tokens( + self): + tokenizer_list = [] + tokenizer_list.append((self.tokenizer_class, self.get_tokenizer())) + + for tokenizer_class, tokenizer_utils in tokenizer_list: + + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer_utils.save_pretrained(tmp_dir) + + with open(os.path.join(tmp_dir, "special_tokens_map.json"), + encoding="utf-8") as json_file: + special_tokens_map = json.load(json_file) + + with open(os.path.join(tmp_dir, "tokenizer_config.json"), + encoding="utf-8") as json_file: + tokenizer_config = json.load(json_file) + + added_tokens_extra_ids = [f"" for i in range(100)] + + special_tokens_map[ + "additional_special_tokens"] = added_tokens_extra_ids + [ + "an_additional_special_token" + ] + tokenizer_config[ + "additional_special_tokens"] = added_tokens_extra_ids + [ + "an_additional_special_token" + ] + + with open(os.path.join(tmp_dir, "special_tokens_map.json"), + "w", + encoding="utf-8") as outfile: + json.dump(special_tokens_map, outfile) + with open(os.path.join(tmp_dir, "tokenizer_config.json"), + "w", + encoding="utf-8") as outfile: + json.dump(tokenizer_config, outfile) + + # the following checks allow us to verify that our test works as expected, i.e. 
+                # the following checks allow us to verify that our test works as expected, i.e. that the tokenizer takes
+                # into account the new value of additional_special_tokens given in the "tokenizer_config.json" and
+                # "special_tokens_map.json" files
+                tokenizer_without_change_in_init = tokenizer_class.from_pretrained(
+                    tmp_dir, )
+                self.assertIn(
+                    "an_additional_special_token",
+                    tokenizer_without_change_in_init.additional_special_tokens)
+                # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # ByT5Tokenization no vocab
+                self.assertEqual(
+                    ["an_additional_special_token"],
+                    tokenizer_without_change_in_init.convert_ids_to_tokens(
+                        tokenizer_without_change_in_init.convert_tokens_to_ids(
+                            ["an_additional_special_token"])),
+                )
+
+                # Now we test that we can change the value of additional_special_tokens in the from_pretrained
+                new_added_tokens = added_tokens_extra_ids + [
+                    AddedToken("a_new_additional_special_token", lstrip=True)
+                ]
+                tokenizer = tokenizer_class.from_pretrained(
+                    tmp_dir,
+                    additional_special_tokens=new_added_tokens,
+                )
+
+                self.assertIn("a_new_additional_special_token",
+                              tokenizer.additional_special_tokens)
+                self.assertEqual(
+                    ["a_new_additional_special_token"],
+                    tokenizer.convert_ids_to_tokens(
+                        tokenizer.convert_tokens_to_ids(
+                            ["a_new_additional_special_token"])),
+                )
+
+    # overwritten from `test_tokenization_common` since T5 has no max length
+    def test_pretrained_model_lists(self):
+        # We should have at least one default checkpoint for each tokenizer
+        # We should specify the max input length as well (used in some part to list the pretrained checkpoints)
+        self.assertGreaterEqual(
+            len(self.tokenizer_class.pretrained_resource_files_map), 1)
+        self.assertGreaterEqual(
+            len(
+                list(
+                    self.tokenizer_class.pretrained_resource_files_map.values())
+                [0]), 1)
+
+    def test_offsets_mapping(self):
+        pass
diff --git a/tests/transformers/test_generation_utils.py b/tests/transformers/test_generation_utils.py
index 8d80e668290b..c6031f641971 100644
--- a/tests/transformers/test_generation_utils.py
+++ b/tests/transformers/test_generation_utils.py
@@ -83,19 +83,19 @@ def _get_logits_processor_and_kwargs(
         forced_bos_token_id=None,
         forced_eos_token_id=None,
         max_length=None,
-        diversity_penalty=None,
+        diversity_rate=None,
 ):
     process_kwargs = {
         "min_length": 1 if max_length is None else max_length - 1,
         "repetition_penalty": 1.2,
     }
-    if diversity_penalty is not None:
-        process_kwargs["diversity_rate"] = diversity_penalty
+    if diversity_rate is not None:
+        process_kwargs["diversity_rate"] = diversity_rate
     logits_processor = LogitsProcessorList(([
         HammingDiversityLogitsProcessor(
-            diversity_penalty, num_beams=2, num_beam_groups=2),
-    ] if diversity_penalty is not None else []) + ([
+            diversity_rate, num_beams=2, num_beam_groups=2),
+    ] if diversity_rate is not None else []) + ([
         MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id
                                  ),
     ] if eos_token_id is not None else []) + ([
@@ -143,7 +143,7 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size,
         "num_beams": 2,
         "num_return_sequences": num_return_sequences,
         "num_beam_groups": 2,  # one beam per group
-        "diversity_penalty": 2.0,
+        "diversity_rate": 2.0,
     }
     beam_scorer = BeamSearchScorer(
         batch_size=batch_size,
@@ -171,6 +171,9 @@ def _get_encoder_outputs(
             input_ids,
             attention_mask=attention_mask,
         )
+        if isinstance(encoder_outputs, (list, tuple)):
+            encoder_outputs = encoder_outputs[0]
+
         encoder_outputs = encoder_outputs.repeat_interleave(num_interleave,
                                                             axis=0)
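`_get_encoder_outputs` now unwraps list/tuple results because encoders such as T5's return a bare tensor when no optional outputs are requested and a tuple otherwise. A minimal, framework-free sketch of that normalization:

def normalize_encoder_output(encoder_outputs):
    # Encoders may return `hidden_states` alone or a tuple such as
    # (hidden_states, all_hidden_states, attentions); keep the first item.
    if isinstance(encoder_outputs, (list, tuple)):
        return encoder_outputs[0]
    return encoder_outputs

assert normalize_encoder_output(("h",)) == "h"
assert normalize_encoder_output("h") == "h"
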
@@ -368,6 +371,7 @@ def _group_beam_search_generate(
         logits_processor,
         logits_process_kwargs,
     ):
+        beam_kwargs.pop("diversity_rate")
         model.eval()
         with paddle.no_grad():
             output_generate = model.generate(
@@ -593,7 +597,7 @@ def test_group_beam_search_generate(self):
             getattr(config, "forced_bos_token_id", None),
             getattr(config, "forced_eos_token_id", None),
             max_length,
-            diversity_penalty=2.0,
+            diversity_rate=2.0,
         )
 
         # check `generate()` and `group_beam_search()` are equal
@@ -790,7 +794,7 @@ def test_diverse_beam_search(self):
             num_beams=4,
             num_return_sequences=3,
             num_beam_groups=4,
-            diversity_penalty=2.0,
+            diversity_rate=2.0,
         )
 
         generated_text = bart_tokenizer.batch_decode(outputs,
diff --git a/tests/transformers/test_tokenizer_common.py b/tests/transformers/test_tokenizer_common.py
index 3316d91df773..2aae52804272 100644
--- a/tests/transformers/test_tokenizer_common.py
+++ b/tests/transformers/test_tokenizer_common.py
@@ -947,7 +947,8 @@ def test_maximum_encoding_length_single_input(self):
 
                 sequence1 = tokenizer(seq_1,
                                       return_token_type_ids=None,
-                                      add_special_tokens=False)
+                                      add_special_tokens=False,
+                                      truncation=False)
                 total_length1 = len(sequence1["input_ids"])
                 self.assertGreater(
                     total_length1, model_max_length,
@@ -1080,12 +1081,14 @@ def test_maximum_encoding_length_pair_input(self):
 
                 sequence1 = tokenizer(seq_1,
                                       return_token_type_ids=None,
-                                      add_special_tokens=False)
+                                      add_special_tokens=False,
+                                      truncation=False)
                 total_length1 = len(sequence1["input_ids"])
                 sequence2 = tokenizer(seq_2,
                                       seq_1,
                                       return_token_type_ids=None,
-                                      add_special_tokens=False)
+                                      add_special_tokens=False,
+                                      truncation=False)
                 total_length2 = len(sequence2["input_ids"])
                 self.assertLess(
                     total_length1, model_max_length - 10,
@@ -1900,25 +1903,46 @@ def test_call(self):
                 ]
 
                 # Test not batched
-                encoded_sequences_1 = tokenizer.encode(sequences[0])
-                encoded_sequences_2 = tokenizer(sequences[0])
+                encoded_sequences_1 = tokenizer.encode(
+                    sequences[0],
+                    return_token_type_ids=False,
+                    return_attention_mask=True)
+                encoded_sequences_2 = tokenizer(sequences[0],
+                                                return_token_type_ids=False,
+                                                return_attention_mask=True)
                 self.assertEqual(encoded_sequences_1, encoded_sequences_2)
 
                 # Test not batched pairs
-                encoded_sequences_1 = tokenizer.encode(sequences[0],
-                                                       sequences[1])
-                encoded_sequences_2 = tokenizer(sequences[0], sequences[1])
+                encoded_sequences_1 = tokenizer.encode(
+                    sequences[0],
+                    sequences[1],
+                    return_token_type_ids=False,
+                    return_attention_mask=True)
+                encoded_sequences_2 = tokenizer(sequences[0],
+                                                sequences[1],
+                                                return_token_type_ids=False,
+                                                return_attention_mask=True)
                 self.assertEqual(encoded_sequences_1, encoded_sequences_2)
 
                 # Test batched
-                encoded_sequences_1 = tokenizer.batch_encode(sequences)
-                encoded_sequences_2 = tokenizer(sequences)
+                encoded_sequences_1 = tokenizer.batch_encode(
+                    sequences,
+                    return_token_type_ids=False,
+                    return_attention_mask=True)
+                encoded_sequences_2 = tokenizer(sequences,
+                                                return_token_type_ids=False,
+                                                return_attention_mask=True)
                 self.assertEqual(encoded_sequences_1, encoded_sequences_2)
 
                 # Test batched pairs
                 encoded_sequences_1 = tokenizer.batch_encode(
-                    list(zip(sequences, sequences)))
-                encoded_sequences_2 = tokenizer(sequences, sequences)
+                    list(zip(sequences, sequences)),
+                    return_token_type_ids=False,
+                    return_attention_mask=True)
+                encoded_sequences_2 = tokenizer(sequences,
+                                                sequences,
+                                                return_token_type_ids=False,
+                                                return_attention_mask=True)
                 self.assertEqual(encoded_sequences_1, encoded_sequences_2)
 
     def test_batch_encode_plus_batch_sequence_length(self):
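Taken together, the cache guard in generation_utils.py and the encoder-output handling in modeling.py are what let `T5ForConditionalGeneration` run `generate()`. A hedged end-to-end sketch (checkpoint name taken from the archive list added above; the `(ids, scores)` return convention of PaddleNLP's `generate()` is assumed):

import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

inputs = tokenizer("translate English to German: The house is wonderful.")
input_ids = paddle.to_tensor([inputs["input_ids"]])

with paddle.no_grad():
    generated_ids, scores = model.generate(input_ids,
                                           max_length=20,
                                           decode_strategy="greedy_search")

print(tokenizer.convert_ids_to_tokens(generated_ids[0].tolist()))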