Add option to use fast HF tokenizer. #482

Merged: 32 commits, merged on Sep 2, 2020
Commits (32)
651463a  Add option to use fast HF tokenizer (PhilipMay, Aug 1, 2020)
a433483  Hand merge tests from PR #205 (PhilipMay, Aug 1, 2020)
c20c5db  test_inferencer_with_fast_bert_tokenizer (PhilipMay, Aug 1, 2020)
5f2b5ee  test_fast_bert_tokenizer (PhilipMay, Aug 1, 2020)
fa3bd67  test_fast_bert_tokenizer_strip_accents (PhilipMay, Aug 1, 2020)
cd7298c  test_fast_electra_tokenizer (PhilipMay, Aug 1, 2020)
01e5ffb  Fix OOM issue of CI (PhilipMay, Aug 1, 2020)
42f345f  Extend test for fast tokenizer (PhilipMay, Aug 2, 2020)
9b021ff  test_fast_tokenizer for more model typed (PhilipMay, Aug 2, 2020)
86d7fd5  Fix tokenize_with_metadata (PhilipMay, Aug 2, 2020)
a8f4638  Split tokenizer tests (PhilipMay, Aug 2, 2020)
cdccafa  Fix pytest params bug in test_tok (PhilipMay, Aug 2, 2020)
47d4b6a  Fix fast tokenizer usage (PhilipMay, Aug 4, 2020)
8318063  add missing newline eof (PhilipMay, Aug 4, 2020)
8c61e3b  Add test fast tok. doc_callif. (PhilipMay, Aug 4, 2020)
aec7d2d  Remove RobertaTokenizerFast (PhilipMay, Aug 4, 2020)
75ea9dd  Fix Tokenizer load and save. (PhilipMay, Aug 4, 2020)
2d2cd00  Fix typo (PhilipMay, Aug 4, 2020)
8afa136  Improve test test_embeddings_extraction (PhilipMay, Aug 5, 2020)
042fde0  Dosctring for fast tokenizers improved (PhilipMay, Aug 5, 2020)
7ed385f  tokenizer_args docstring (PhilipMay, Aug 5, 2020)
d4eb59c  Extend test_embeddings_extraction to fast tok. (PhilipMay, Aug 5, 2020)
4f87604  extend test_ner with fast tok. (PhilipMay, Aug 5, 2020)
bc7abca  fix sample_to_features_ner for fast tokenizer (tholor, Aug 6, 2020)
da9c2f5  temp fix for is_pretokenized until fixed upstream (tholor, Aug 6, 2020)
19cc211  Make use of fast tokenizer possible + fix bug in offset calculation (bogdankostic, Aug 25, 2020)
6d0a3c1  Merge branch 'master' into add_fast_tokenizer (bogdankostic, Aug 25, 2020)
7e75de1  Make fast tokenization possible with NER, LM and QA (bogdankostic, Aug 31, 2020)
0e4b1b0  Merge remote-tracking branch 'origin/add_fast_tokenizer' into add_fas… (bogdankostic, Aug 31, 2020)
eb46629  Change error messages (bogdankostic, Sep 1, 2020)
06d51c0  Add tests (bogdankostic, Sep 1, 2020)
1acaff4  update error messages, comments and truncation arg in tokenizer (tholor, Sep 2, 2020)
162 changes: 126 additions & 36 deletions farm/data_handler/input_features.py
@@ -4,6 +4,7 @@


import logging
import re
import collections
from dotmap import DotMap
import numpy as np
@@ -36,18 +37,34 @@ def sample_to_features_text(
:rtype: list
"""

#TODO It might be cleaner to adjust the data structure in sample.tokenized
# Verify if this current quickfix really works for pairs
tokens_a = sample.tokenized["tokens"]
tokens_b = sample.tokenized.get("tokens_b", None)

inputs = tokenizer.encode_plus(
tokens_a,
tokens_b,
add_special_tokens=True,
truncation_strategy='do_not_truncate',
return_token_type_ids=True
)
if tokenizer.is_fast:
text = sample.clear_text["text"]
# Here, we tokenize the sample for the second time to get all relevant ids
# This should change once we get rid of FARM's tokenize_with_metadata()
inputs = tokenizer(text,
return_token_type_ids=True,
max_length=max_seq_len,
return_special_tokens_mask=True)

if (len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)) != len(sample.tokenized["tokens"]):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(inputs['input_ids']) - inputs['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata(). \n"
f"Further processing is likely to be wrong.")
else:
# TODO It might be cleaner to adjust the data structure in sample.tokenized
tokens_a = sample.tokenized["tokens"]
tokens_b = sample.tokenized.get("tokens_b", None)

inputs = tokenizer.encode_plus(
tokens_a,
tokens_b,
add_special_tokens=True,
truncation=False, # truncation_strategy is deprecated
return_token_type_ids=True,
max_length=max_seq_len,
is_pretokenized=False,
)

input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"]
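For readers tracing the fast-tokenizer branch above: the raw text is re-encoded and the number of non-special tokens is compared against what tokenize_with_metadata() produced earlier. A minimal sketch of that consistency check, using a Hugging Face fast BERT tokenizer (model name and example text chosen only for illustration):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
text = "Berlin is the capital of Germany."

# Stand-in for the tokens produced earlier by tokenize_with_metadata().
pre_tokens = tokenizer.tokenize(text)

# Re-encode the raw text to get input ids plus a mask marking [CLS]/[SEP] positions.
inputs = tokenizer(text, return_token_type_ids=True, return_special_tokens_mask=True)
n_real_tokens = len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)

# The counts must match, otherwise downstream label/feature alignment would be wrong.
assert n_real_tokens == len(pre_tokens)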

@@ -136,13 +153,30 @@ def samples_to_features_ner(
"""

tokens = sample.tokenized["tokens"]
inputs = tokenizer.encode_plus(text=tokens,
text_pair=None,
add_special_tokens=True,
truncation_strategy='do_not_truncate', # We've already truncated our tokens before
return_special_tokens_mask=True,
return_token_type_ids=True
)

if tokenizer.is_fast:
text = sample.clear_text["text"]
# Here, we tokenize the sample for the second time to get all relevant ids
# This should change once we get rid of FARM's tokenize_with_metadata()
inputs = tokenizer(text,
return_token_type_ids=True,
max_length=max_seq_len,
return_special_tokens_mask=True)

if (len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)) != len(sample.tokenized["tokens"]):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(inputs['input_ids']) - inputs['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata().\n"
f"Further processing is likely to be wrong!")
else:
inputs = tokenizer.encode_plus(text=tokens,
text_pair=None,
add_special_tokens=True,
truncation=False,
return_special_tokens_mask=True,
return_token_type_ids=True,
is_pretokenized=False
)

input_ids, segment_ids, special_tokens_mask = inputs["input_ids"], inputs["token_type_ids"], inputs["special_tokens_mask"]
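The NER path re-tokenizes the clear text in the same way. One reason the fast tokenizers are attractive here is that they can return character offsets directly, which FARM otherwise reconstructs in tokenize_with_metadata(). A small illustration of that capability (not part of this diff; model name and text are arbitrary):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
text = "Angela Merkel lives in Berlin."

# Only fast tokenizers support return_offsets_mapping.
encoded = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"])
for token, (start, end) in zip(tokens, encoded["offset_mapping"]):
    # Each token is paired with its character span in the original string.
    print(token, text[start:end])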

@@ -231,6 +265,14 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T

tokens_b, t2_label = mask_random_words(tokens_b, tokenizer.vocab,
token_groups=sample.tokenized["text_b"]["start_of_word"])

if tokenizer.is_fast:
# Detokenize input as fast tokenizer can't handle tokenized input
tokens_a = " ".join(tokens_a)
tokens_a = re.sub(r"(^|\s)(##)", "", tokens_a)
tokens_b = " ".join(tokens_b)
tokens_b = re.sub(r"(^|\s)(##)", "", tokens_b)

# convert lm labels to ids
t1_label_ids = [-1 if tok == '' else tokenizer.convert_tokens_to_ids(tok) for tok in t1_label]
t2_label_ids = [-1 if tok == '' else tokenizer.convert_tokens_to_ids(tok) for tok in t2_label]
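The detokenization above rejoins WordPiece tokens into a plain string because, in this code path, fast tokenizers are fed raw text rather than pre-tokenized input. A minimal sketch of what the regex does, on a made-up token list:

import re

tokens = ["un", "##break", "##able", "record"]
joined = " ".join(tokens)  # "un ##break ##able record"
detokenized = re.sub(r"(^|\s)(##)", "", joined)
# Drops the space before each "##" continuation piece together with the marker,
# yielding "unbreakable record".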
@@ -246,18 +288,39 @@
tokens_b = None
tokens_a, t1_label = mask_random_words(tokens_a, tokenizer.vocab,
token_groups=sample.tokenized["text_a"]["start_of_word"])
if tokenizer.is_fast:
# Detokenize input as fast tokenizer can't handle tokenized input
tokens_a = " ".join(tokens_a)
tokens_a = re.sub(r"(^|\s)(##)", "", tokens_a)

# convert lm labels to ids
lm_label_ids = [-1 if tok == '' else tokenizer.convert_tokens_to_ids(tok) for tok in t1_label]

# encode string tokens to input_ids and add special tokens
inputs = tokenizer.encode_plus(text=tokens_a,
text_pair=tokens_b,
add_special_tokens=True,
truncation_strategy='do_not_truncate',
# We've already truncated our tokens before
return_special_tokens_mask=True,
return_token_type_ids=True
)
if tokenizer.is_fast:
inputs = tokenizer(text=tokens_a,
text_pair=tokens_b,
add_special_tokens=True,
return_special_tokens_mask=True,
return_token_type_ids=True)

seq_b_len = len(sample.tokenized["text_b"]["tokens"]) if "text_b" in sample.tokenized else 0
if (len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)) != \
(len(sample.tokenized["text_a"]["tokens"]) + seq_b_len):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(inputs['input_ids']) - inputs['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata(). \n"
f"Further processing is likely to be wrong.")
else:
# encode string tokens to input_ids and add special tokens
inputs = tokenizer.encode_plus(text=tokens_a,
text_pair=tokens_b,
add_special_tokens=True,
truncation=False,
truncation_strategy='do_not_truncate',
# We've already truncated our tokens before
return_special_tokens_mask=True,
return_token_type_ids=True
)

input_ids, segment_ids, special_tokens_mask = inputs["input_ids"], inputs["token_type_ids"], inputs[
"special_tokens_mask"]
@@ -358,12 +421,35 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks
# (question_len_t + passage_len_t + n_special_tokens). This may be less than max_seq_len but will not be greater
# than max_seq_len since truncation was already performed when the document was chunked into passages
# (c.f. create_samples_squad() )
encoded = tokenizer.encode_plus(text=sample.tokenized["question_tokens"],
text_pair=sample.tokenized["passage_tokens"],
add_special_tokens=True,
truncation_strategy='do_not_truncate',
return_token_type_ids=True,
return_tensors=None)

if tokenizer.is_fast:
# Detokenize input as fast tokenizer can't handle tokenized input
question_tokens = " ".join(question_tokens)
question_tokens = re.sub(r"(^|\s)(##)", "", question_tokens)
passage_tokens = " ".join(passage_tokens)
passage_tokens = re.sub(r"(^|\s)(##)", "", passage_tokens)

encoded = tokenizer(text=question_tokens,
text_pair=passage_tokens,
add_special_tokens=True,
return_special_tokens_mask=True,
return_token_type_ids=True)

if (len(encoded["input_ids"]) - encoded["special_tokens_mask"].count(1)) != \
(len(sample.tokenized["question_tokens"]) + len(sample.tokenized["passage_tokens"])):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(encoded['input_ids']) - encoded['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata(). \n"
f"Further processing is likely to be wrong.")
else:
encoded = tokenizer.encode_plus(text=sample.tokenized["question_tokens"],
text_pair=sample.tokenized["passage_tokens"],
add_special_tokens=True,
truncation=False,
truncation_strategy='do_not_truncate',
return_token_type_ids=True,
return_tensors=None)

input_ids = encoded["input_ids"]
segment_ids = encoded["token_type_ids"]

@@ -467,8 +553,12 @@ def combine_vecs(question_vec, passage_vec, tokenizer, spec_tok_val=-1):
# Join question_label_vec and passage_label_vec and add slots for special tokens
vec = tokenizer.build_inputs_with_special_tokens(token_ids_0=question_vec,
token_ids_1=passage_vec)
spec_toks_mask = tokenizer.get_special_tokens_mask(token_ids_0=question_vec,
token_ids_1=passage_vec)
if tokenizer.is_fast:
spec_toks_mask = tokenizer.get_special_tokens_mask(token_ids_0=vec,
already_has_special_tokens=True)
else:
spec_toks_mask = tokenizer.get_special_tokens_mask(token_ids_0=question_vec,
token_ids_1=passage_vec)

# If a value in vec corresponds to a special token, it will be replaced with spec_tok_val
combined = [v if not special_token else spec_tok_val for v, special_token in zip(vec, spec_toks_mask)]
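The combine_vecs() change above derives the special-token mask from the already-combined sequence when a fast tokenizer is in use. A small sketch of the two ways of computing that mask, shown with a slow BERT tokenizer purely for illustration (tokens and model name are arbitrary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
question_vec = tokenizer.convert_tokens_to_ids(["who", "won"])
passage_vec = tokenizer.convert_tokens_to_ids(["the", "final", "match"])

# Combined sequence with [CLS]/[SEP] already inserted.
vec = tokenizer.build_inputs_with_special_tokens(question_vec, passage_vec)

# Variant used for fast tokenizers: derive the mask from the finished sequence.
mask_from_vec = tokenizer.get_special_tokens_mask(vec, already_has_special_tokens=True)

# Variant used for slow tokenizers: let the tokenizer rebuild the pair layout.
mask_from_pair = tokenizer.get_special_tokens_mask(question_vec, token_ids_1=passage_vec)

# Both masks mark the [CLS] and [SEP] positions with 1.
assert mask_from_vec == mask_from_pair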
6 changes: 5 additions & 1 deletion farm/data_handler/processor.py
@@ -231,7 +231,11 @@ def save(self, save_dir):
config = self.generate_config()
# save tokenizer incl. attributes
config["tokenizer"] = self.tokenizer.__class__.__name__
self.tokenizer.save_pretrained(save_dir)

# Fast tokenizers expect a str rather than a Path,
# so always convert Path to str here.
self.tokenizer.save_pretrained(str(save_dir))

# save processor
config["processor"] = self.__class__.__name__
output_config_file = Path(save_dir) / "processor_config.json"
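A minimal sketch of the resulting save pattern (tokenizer class and directory are illustrative):

from pathlib import Path
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
save_dir = Path("saved_models/my_model")
save_dir.mkdir(parents=True, exist_ok=True)

# Fast tokenizers (at the time of this PR) expect a plain str, not a pathlib.Path.
tokenizer.save_pretrained(str(save_dir))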
20 changes: 19 additions & 1 deletion farm/infer.py
@@ -167,6 +167,9 @@ def load(
s3e_stats=None,
num_processes=None,
disable_tqdm=False,
tokenizer_class=None,
use_fast=False,
tokenizer_args=None,
dummy_ph=False,
benchmarking=False,

@@ -212,6 +215,15 @@ def load(
:type num_processes: int
:param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
:type disable_tqdm: bool
:param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
:type tokenizer_class: str
:param use_fast: (Optional, False by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
use the Python one (False).
:param tokenizer_args: (Optional) Will be passed to the Tokenizer ``__init__`` method.
See https://huggingface.co/transformers/main_classes/tokenizer.html and detailed tokenizer documentation
on `Hugging Face Transformers <https://huggingface.co/transformers/>`_.
:type tokenizer_args: dict
:type use_fast: bool
:param dummy_ph: If True, methods of the prediction head will be replaced
with a dummy method. This is used to isolate lm run time from ph run time.
:type dummy_ph: bool
@@ -223,6 +235,8 @@
:return: An instance of the Inferencer.

"""
if tokenizer_args is None:
tokenizer_args = {}

device, n_gpu = initialize_device_settings(use_cuda=gpu, local_rank=-1, use_amp=None)
name = os.path.basename(model_name_or_path)
@@ -250,7 +264,11 @@

model = AdaptiveModel.convert_from_transformers(model_name_or_path, device, task_type)
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = Tokenizer.load(model_name_or_path)
tokenizer = Tokenizer.load(model_name_or_path,
tokenizer_class=tokenizer_class,
use_fast=use_fast,
**tokenizer_args,
)

# TODO infer task_type automatically from config (if possible)
if task_type == "question_answering":
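Taken together, the new arguments let callers opt in to the fast tokenizer when loading an Inferencer. A minimal usage sketch (model name and tokenizer arguments are illustrative, not prescribed by this PR):

from farm.infer import Inferencer

inferencer = Inferencer.load(
    "deepset/bert-base-cased-squad2",          # any compatible model; chosen for illustration
    task_type="question_answering",
    gpu=False,
    use_fast=True,                             # load the Rust-backed HF tokenizer
    tokenizer_args={"do_lower_case": False},   # forwarded to the tokenizer __init__
)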