-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
huggingface.py
694 lines (623 loc) · 28.2 KB
/
huggingface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
import math
import torch
import torch.nn.functional as F
import transformers
import peft
from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
from transformers import BatchEncoding
from lm_eval import utils
from lm_eval.base import BaseLM
# A tokenized sequence as handled by the tokenizer/model helpers below: a
# plain list of token ids, a tensor of ids, or a tokenizer `BatchEncoding`.
TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
# Explicit `device_map` accepted by `from_pretrained`: module name -> device
# (index, name or `torch.device`). A plain type alias rather than `NewType`
# so that ordinary dicts type-check wherever this annotation is used.
_DeviceMapping = Mapping[str, Union[int, str, torch.device]]
def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Build the `accelerate` kwargs for `AutoModel.from_pretrained`.

    Args:
        device_map_option: Value forwarded as `device_map`.
        max_memory_per_gpu: Memory limit applied to every visible CUDA device,
            or `None` for no per-GPU limit.
        max_cpu_memory: Memory limit for the `"cpu"` entry, or `None`.
        offload_folder: Value forwarded as `offload_folder`.

    Returns:
        A dict with `device_map` and `offload_folder`, preceded by a
        `max_memory` mapping (GPU index and/or `"cpu"` -> limit) only when at
        least one limit ends up being set.
    """
    gpu_limits = (
        {}
        if max_memory_per_gpu is None
        else dict.fromkeys(range(torch.cuda.device_count()), max_memory_per_gpu)
    )
    cpu_limits = {} if max_cpu_memory is None else {"cpu": max_cpu_memory}
    memory_limits = {**gpu_limits, **cpu_limits}

    # `max_memory` is omitted entirely rather than passed as an empty mapping.
    accelerate_kwargs = {"max_memory": memory_limits} if memory_limits else {}
    accelerate_kwargs["device_map"] = device_map_option
    accelerate_kwargs["offload_folder"] = offload_folder
    return accelerate_kwargs
def _get_dtype(
    dtype: Optional[Union[str, torch.dtype]],
    config: Optional["transformers.AutoConfig"] = None,
) -> Optional[Union[str, torch.dtype]]:
    """Converts `dtype` from `str` to `torch.dtype` when possible.

    Args:
        dtype: A `torch.dtype`, the name of one (`"float16"` or
            `"torch.float16"`), `"auto"` (returned unchanged so that
            `from_pretrained` derives the dtype from the weights), or `None`.
        config: Model config consulted only when `dtype` is `None`; its
            `torch_dtype` attribute is returned in that case.

    Returns:
        The resolved `torch.dtype`, the string `"auto"`, or `None` when
        neither `dtype` nor `config` is given.

    Raises:
        ValueError: If `dtype` is a string that does not name a `torch.dtype`.
    """
    if dtype is None and config is not None:
        return config.torch_dtype
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args to torch dtype: `float16` -> `torch.float16`.
        name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
        torch_dtype = getattr(torch, name, None)
        # A bare `getattr` would also accept unrelated `torch` attributes
        # (functions, submodules) and only fail deep inside model loading.
        if not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"Invalid dtype {dtype!r}: expected the name of a torch dtype "
                "(e.g. 'float16', 'bfloat16', 'float32') or 'auto'."
            )
        return torch_dtype
    return dtype
class HuggingFaceAutoLM(BaseLM):
    """`BaseLM` adapter around a HuggingFace `transformers` model and tokenizer.

    Subclasses select the model family by setting `AUTO_MODEL_CLASS` (and
    `AUTO_PEFT_CLASS` to allow loading PEFT adapters) and implement the
    `_model_call` / `_model_generate` hooks used by the request methods.
    """

    AUTO_CONFIG_CLASS: transformers.AutoConfig = transformers.AutoConfig
    AUTO_TOKENIZER_CLASS: transformers.AutoTokenizer = transformers.AutoTokenizer
    AUTO_MODEL_CLASS: transformers.AutoModel = None
    AUTO_PEFT_CLASS: peft.PeftModel = None

    # Default max sequence length setting for when no `max_length` is provided
    # or no max length config setting is found in the model or tokenizer.
    _DEFAULT_MAX_LENGTH: int = 2048

    def __init__(
        self,
        pretrained: str,
        tokenizer: Optional[str] = None,
        subfolder: Optional[str] = None,
        revision: Optional[str] = "main",
        batch_size: Optional[int] = 1,
        max_gen_toks: Optional[int] = 256,
        max_length: Optional[int] = None,
        add_special_tokens: Optional[bool] = None,
        use_accelerate: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[int, str]] = "cuda",
        peft: str = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
    ):
        """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.

        Args:
            pretrained (str):
                The HuggingFace Hub model ID name or the path to a pre-trained
                model to load. This is effectively the `pretrained_model_name_or_path`
                argument of `from_pretrained` in the HuggingFace `transformers` API.
            tokenizer (str, optional, defaults to None):
                Hub ID or path of the tokenizer to load. Defaults to `pretrained`.
            subfolder (str, optional, defaults to None):
                Subfolder of the repo/path that holds the files. Forwarded as
                the `subfolder` argument of every `from_pretrained` call.
            revision (str, optional, defaults to "main"):
                Hub revision (branch, tag or commit) to load.
            batch_size (int, optional, defaults to 1):
                Number of requests processed per model call.
            max_gen_toks (int, optional, defaults to 256):
                Number of new tokens generated by `greedy_until` when a request
                does not specify its own `max_length`.
            max_length (int, optional, defaults to None):
                Maximum sequence length. When `None` it is read from the model
                config, then the tokenizer, then `_DEFAULT_MAX_LENGTH`.
            add_special_tokens (bool, optional, defaults to None):
                Whether to add special tokens to the input sequences. If `None`, the
                default value will be set to `True` for seq2seq models (e.g. T5) and
                `False` for causal models.
                WARNING: Evaluating causal models with `add_special_tokens=True` is
                currently __not__ supported.

            > Large model loading `accelerate` arguments
            use_accelerate (bool, optional, defaults to False):
                If True, uses the `accelerate` library to load a large model across
                multiple devices.
            device_map_option (str, optional, defaults to "auto"):
                The device map option to use when loading the model with
                `accelerate`.
                Options:
                    "auto", "balanced", "balanced_low_0", "sequential"
                See the `accelerate` docs for more details on these options:
                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.device_map
            max_memory_per_gpu (Union[int, str], optional, defaults to None):
                The maximum memory available for each GPU in bytes as `int` or in
                the format f"{significand}{unit_symbol}" where {unit_symbol} is
                any of ["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in
                the "Parameters for big model inference" section of the following
                docs:
                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory
            max_cpu_memory (Union[int, str], optional, defaults to None):
                The maximum available CPU RAM in bytes as `int` or in the format
                f"{significand}{unit_symbol}" where {unit_symbol} is any of
                ["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in the
                "Parameters for big model inference" section of the following docs:
                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory
            offload_folder (str, optional, defaults to "./offload"):
                The folder to offload weights into if `device_map` contains any
                "disk" value.
            dtype (Union[str, torch.dtype], optional, defaults to None):
                Converts the model weights to `dtype`, if specified. Strings get
                converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
                Use `dtype="auto"` to derive the type from the model's weights.
                When `None`, the dtype recorded in the model config is used.
            device (Union[int, str], optional, defaults to "cuda"):
                Device the model is moved to when `use_accelerate` is False.
                With `accelerate`, the device of `lm_head` is used instead when
                the device map contains it.
            peft (str, optional, defaults to None):
                Path of the adapter weights to load from Huggingface. This will usually
                include a directory that includes the files `adapter_config.json` and
                `adapter_model.bin`. Compatible with [PEFT](https://github.com/huggingface/peft)
            load_in_8bit (bool, optional, defaults to False):
                If True, will convert the loaded model into mixed-8bit quantized model.
                Only forwarded when `use_accelerate` is True. See:
                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit
            trust_remote_code (bool, optional, defaults to False):
                If True, will trust the remote code when loading the model.
        """
        super().__init__()
        assert isinstance(pretrained, str)
        # A device name ("cuda", "cuda:1", "cpu") or a CUDA device index.
        assert isinstance(device, (int, str))
        assert isinstance(batch_size, int)
        if (
            add_special_tokens is not None
            and self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM
        ):
            # TODO: Support evaluating causal models with special tokens. Currently,
            # this is not possible because the `_loglikelihood_tokens()` method for
            # causal LMs makes a no-special-tokens assumption given that contexts
            # and labels/continuations are tokenized separately without special
            # tokens, concatenated, and then processed as inputs.
            assert (
                not add_special_tokens
            ), "Evaluating causal models with `add_special_tokens=True` is currently not supported."

        self._batch_size = batch_size  # TODO: Adaptive batch size
        self._max_gen_toks = max_gen_toks
        self._max_length = max_length
        self._config = self.AUTO_CONFIG_CLASS.from_pretrained(
            pretrained,
            trust_remote_code=trust_remote_code,
            **self._hub_kwargs(revision, subfolder),
        )

        self._add_special_tokens = add_special_tokens
        self.tokenizer = self._create_auto_tokenizer(
            pretrained=pretrained,
            revision=revision,
            subfolder=subfolder,
            tokenizer=tokenizer,
        )
        self.tokenizer.model_max_length = self.max_length

        model_kwargs = {}
        if use_accelerate:
            model_kwargs = _get_accelerate_args(
                device_map_option,
                max_memory_per_gpu,
                max_cpu_memory,
                offload_folder,
            )
            model_kwargs["load_in_8bit"] = load_in_8bit
        torch_dtype = _get_dtype(dtype, self._config)
        self.model = self._create_auto_model(
            pretrained=pretrained,
            trust_remote_code=trust_remote_code,
            revision=revision,
            subfolder=subfolder,
            torch_dtype=torch_dtype,
            **model_kwargs,
        )
        # note: peft_path can be different than pretrained model path
        if peft is not None:
            self.model = self._create_auto_model_peft(
                model=self.model,
                peft=peft,
                revision=revision,
                subfolder=subfolder,
                torch_dtype=torch_dtype,
                **model_kwargs,
            )
        self.model.eval()
        torch.set_grad_enabled(False)

        self._device = device
        if use_accelerate:
            # `accelerate` can place `lm_head` weights on a different device than
            # the user specified one so we force `self._device` to be the same as
            # `lm_head`'s. The device map may be absent (e.g. `device_map=None`).
            device_map = getattr(self.model, "hf_device_map", None) or {}
            if "lm_head" in device_map:
                self._device = device_map["lm_head"]
        else:
            self.model.to(self._device)

    @staticmethod
    def _hub_kwargs(revision: str, subfolder: Optional[str]) -> dict:
        """Returns the location kwargs shared by every `from_pretrained` call.

        `subfolder` is a separate `from_pretrained` argument, not part of the
        git `revision`; it is only forwarded when set.
        """
        kwargs = {"revision": revision}
        if subfolder is not None:
            kwargs["subfolder"] = subfolder
        return kwargs

    def _create_auto_model(
        self,
        *,
        pretrained: str,
        revision: str,
        subfolder: str,
        device_map: Optional[Union[str, _DeviceMapping]] = None,
        max_memory: Optional[dict] = None,
        offload_folder: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
    ) -> transformers.AutoModel:
        """Returns a pre-trained pytorch model from a pre-trained model configuration."""
        model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            **self._hub_kwargs(revision, subfolder),
        )
        return model

    def _create_auto_model_peft(
        self,
        *,
        model: transformers.PreTrainedModel,
        peft: str,
        revision: str,
        subfolder: str,
        device_map: Optional[Union[str, _DeviceMapping]] = None,
        max_memory: Optional[dict] = None,
        offload_folder: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
    ):
        """Wraps `model` with the PEFT adapter weights found at `peft`.

        Raises:
            ValueError: If this class does not define `AUTO_PEFT_CLASS`.
        """
        if self.AUTO_PEFT_CLASS is None:
            raise ValueError(
                f"{type(self).__name__} does not support loading PEFT adapters "
                "(`AUTO_PEFT_CLASS` is not set)."
            )
        model = self.AUTO_PEFT_CLASS.from_pretrained(
            model,
            peft,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            **self._hub_kwargs(revision, subfolder),
        )
        return model

    def _create_auto_tokenizer(
        self,
        *,
        pretrained: str,
        revision: str,
        subfolder: str,
        tokenizer: Optional[str] = None,
    ) -> transformers.PreTrainedTokenizer:
        """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
        tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            **self._hub_kwargs(revision, subfolder),
        )
        # `tok_encode_batch` pads batches, which needs a pad token; the EOS
        # token is reused for that purpose.
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @property
    def add_special_tokens(self) -> bool:
        """Whether to include special tokens in encoded text. This should be
        determined by whether or not the model was trained with special tokens.
        TODO: Remove these conditionals once HuggingFace supports a way to
        check whether or not an arbitrary model was trained with special tokens.
        """
        if self._add_special_tokens is not None:
            return self._add_special_tokens
        elif self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM:
            return False
        elif self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM:
            return True
        else:
            raise ValueError(
                "Could not determine `add_special_tokens` value from the model "
                "class. Set to `True` or `False` depending on whether the model "
                "was pre-trained with special tokens."
            )

    @property
    def eot_token(self) -> str:
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    def max_gen_toks(self) -> int:
        return self._max_gen_toks

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model.

        NOTE: Different model configurations have different max sequence length
        attribute names.
            - n_positions: (CTRLConfig)
            - max_position_embeddings: (BartConfig, RoFormerConfig)
            - n_ctx: (GPT2Config)
        NOTE: For relative position encoded models you should specify the max
        sequence length of the model in the constructor via `max_length`.
        """
        if self._max_length is not None:
            return self._max_length
        # Try to get the sequence length from the model config.
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            value = getattr(self._config, attr, None)
            if value is not None:
                return value
        # Tokenizers without a configured limit report the huge sentinel
        # `VERY_LARGE_INTEGER`; treat that as "unknown" rather than a real length.
        model_max_length = getattr(self.tokenizer, "model_max_length", None)
        if (
            model_max_length is not None
            and model_max_length < transformers.tokenization_utils_base.VERY_LARGE_INTEGER
        ):
            return model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def batch_size(self) -> int:
        # TODO: Add adaptive batch size.
        return self._batch_size  # * gpus

    @property
    def device(self) -> Union[int, str, torch.device]:
        return self._device

    def tok_encode(self, string: str) -> TokenSequence:
        # TODO: Merge `tok_encode_batch` here.
        return self.tokenizer.encode(string, add_special_tokens=self.add_special_tokens)

    def tok_encode_batch(self, strings: List[str]) -> TokenSequence:
        return self.tokenizer(
            strings,
            padding=True,
            add_special_tokens=self.add_special_tokens,
            return_tensors="pt",
        )

    def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

    @staticmethod
    def _parse_generation_args(
        request_args: Union[Mapping[str, object], List[str], str, None],
    ) -> Tuple[Optional[List[str]], Optional[int]]:
        """Splits a `greedy_until` request argument into `(stop_sequences, max_length)`.

        Accepts a dict with optional `"until"` and `"max_length"` keys, or the
        stop sequence(s) given directly as a `str` or a list of `str`.
        `stop_sequences` is `None` when no stop sequence was given.
        """
        if isinstance(request_args, Mapping):
            stop = request_args.get("until", None)
            max_generation_length = request_args.get("max_length", None)
        else:
            stop = request_args
            max_generation_length = None

        if stop is None:
            stop_sequences = None
        elif isinstance(stop, str):
            stop_sequences = [stop]
        else:
            stop_sequences = list(stop)

        assert (
            isinstance(max_generation_length, int) or max_generation_length is None
        )
        assert stop_sequences is None or all(
            isinstance(sequence, str) for sequence in stop_sequences
        )
        return stop_sequences, max_generation_length

    def greedy_until(
        self,
        requests: List[Tuple[str, Union[Mapping[str, object], List[str], str]]],
    ) -> List[str]:
        """Greedily generates a continuation for each `(context, args)` request.

        `args` is either a dict with optional `"until"` (stop string or list of
        stop strings) and `"max_length"` (max new tokens) keys, or the stop
        string(s) directly. Each response is cut at the first stop sequence or
        EOT token. Results are returned in the original request order.
        """

        def _collate(x):
            tokens = self.tok_encode(x[0])
            return len(tokens), x[0]

        results = []
        reorder = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reorder.get_reordered(), disable=False), self.batch_size
        ):
            context = [c[0] for c in chunk]
            # NOTE: generation settings are taken from the first request of the
            # batch and applied to all of its requests.
            stop_sequences, max_generation_length = self._parse_generation_args(
                chunk[0][1]
            )
            # TODO: Find a better way to handle stop sequences for 0-shot.
            if stop_sequences is None:
                until = [self.eot_token]
            else:
                until = stop_sequences + [self.eot_token]
            if max_generation_length is None:
                max_tokens = self.max_gen_toks
            else:
                max_tokens = max_generation_length

            token_context = self.tok_encode_batch(context)
            responses = self._model_generate(
                inputs=token_context,
                max_tokens=max_tokens,
                stop=until,
            )
            responses = self.tok_decode(responses.tolist())

            for request, response in zip(chunk, responses):
                # Ensure the generated responses do not contain the stop sequences.
                for term in until:
                    response = response.split(term)[0]
                # partial caching: key on the original request so that cache
                # lookups for this exact request can hit.
                self.cache_hook.add_partial("greedy_until", request, response)
                results.append(response)
        return reorder.get_original(results)
class AutoCausalLM(HuggingFaceAutoLM):
    """Causal language modeling.

    You can find a set of supported models in the HF documentation:
    https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForCausalLM
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
    AUTO_PEFT_CLASS = peft.PeftModel

    def _create_auto_tokenizer(
        self,
        *,
        pretrained: str,
        revision: str,
        subfolder: str,
        tokenizer: Optional[str] = None,
    ) -> transformers.PreTrainedTokenizer:
        """Returns the base tokenizer configured for left padding."""
        tokenizer = super()._create_auto_tokenizer(
            pretrained=pretrained,
            revision=revision,
            subfolder=subfolder,
            tokenizer=tokenizer,
        )
        # Left padding keeps every prompt flush against the generated tokens,
        # which the continuation selection in `_model_generate` relies on.
        tokenizer.padding_side = "left"
        return tokenizer

    def _model_call(
        self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
    ) -> TokenSequence:
        """Returns the model's logits for `inputs` (`labels` is unused)."""
        return self.model(inputs)["logits"]

    def _model_generate(
        self,
        inputs: transformers.BatchEncoding,
        max_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> TokenSequence:
        """Greedily generates up to `max_tokens` new tokens per row.

        Returns only the newly generated token ids, with the (left-padded)
        context removed.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        # Ensure that the context does not encroach into the space reserved for
        # the `max_tokens` tokens actually requested: keep only the rightmost
        # tokens that fit. If nothing fits, the context is left untouched.
        context_budget = self.max_length - max_tokens
        if context_budget > 0:
            input_ids = input_ids[:, -context_budget:]
            attention_mask = attention_mask[:, -context_budget:]
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0]
        )
        generations = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # GPT style models require the `generate` `max_length` arg to include the
            # context length, so we instead set `max_new_tokens` which is the number
            # of new tokens to generate, excluding the current number of tokens.
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            do_sample=False,
        )
        # `generations` starts with the (possibly truncated) context that was
        # fed to the model, so strip exactly that many columns.
        return utils.select_continuation_from_batch_left_padding(
            generations, max_context_size=input_ids.size(1)
        )
class AutoSeq2SeqLM(HuggingFaceAutoLM):
    """Seq2Seq language modeling.

    You can find a set of supported models in the following documentation:
    https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForSeq2SeqLM
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
    AUTO_PEFT_CLASS = peft.PeftModel

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model.

        TODO: Currently only works for relative position encoded Seq2Seq models.
        """
        if self._max_length is not None:
            return self._max_length
        return self._DEFAULT_MAX_LENGTH

    def loglikelihood(
        self, requests: List[Tuple[str, str]]
    ) -> List[Tuple[float, bool]]:
        """Scores each `(context, continuation)` request.

        The context is encoded as the encoder input and the continuation as
        the decoder target. Returns `(sum of target log-probs, is_greedy)` per
        request, in request order.
        """
        new_requests = []
        for chunk in utils.chunks(requests, self.batch_size):
            # Keep the untouched request strings: they are the partial-cache
            # keys and must match the original requests exactly.
            orig_contexts, orig_continuations = zip(*chunk)

            # Fill empty contexts with the EOT token.
            context = [
                f"{self.eot_token}" if len(text) == 0 else text
                for text in orig_contexts
            ]
            context_enc = self.tok_encode_batch(context)
            for key in context_enc:
                context_enc[key] = context_enc[key][:, -self.max_length :]

            # Remove leading whitespace introduced by the default
            # `text_target_separator` since the context and continuation
            # will not be concatenated as a single (decoder) input.
            continuation = [text.lstrip() for text in orig_continuations]
            continuation_enc = self.tok_encode_batch(continuation)
            for key in continuation_enc:
                continuation_enc[key] = continuation_enc[key][:, -self.max_length :]

            new_requests.append(
                ((orig_contexts, orig_continuations), context_enc, continuation_enc)
            )
        return self._loglikelihood_tokens(new_requests)

    def loglikelihood_rolling(self, requests: List[Tuple[str, str]]) -> List[float]:
        """Returns the total log-likelihood of each `(string,)` request, scored
        over disjoint rolling windows of at most `self.max_length` tokens."""
        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )
            contexts, conts = utils.split_and_pad_windows(
                rolling_token_windows,
                pad_token_id=self.eot_token_id,
                max_seq_len=self.max_length,
            )
            # Manually create BatchEncoding tensors with attention masks as
            # expected by `self._model_call` in `self._loglikelihood_tokens`.
            # NOTE(review): masking every `eot_token_id` also masks genuine EOT
            # tokens, and the prefix token passed above is `eot_token_id` as
            # well; confirm this is intended.
            contexts_enc = torch.Tensor(contexts).long()
            contexts_enc = transformers.tokenization_utils_base.BatchEncoding(
                {
                    "input_ids": contexts_enc,
                    "attention_mask": (contexts_enc != self.eot_token_id).long(),
                }
            )
            conts_enc = torch.Tensor(conts).long()
            conts_enc = transformers.tokenization_utils_base.BatchEncoding(
                {
                    "input_ids": conts_enc,
                    "attention_mask": (conts_enc != self.eot_token_id).long(),
                }
            )
            # TODO: Extract out this call so it only gets called once and also
            # somehow figure out partial caching for.
            # Rolling windows are partial results, so no cache keys are given.
            rolling_token_windows_request = [(None, contexts_enc, conts_enc)]
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows_request, disable_tqdm=True
            )
            string_nll = [x[0] for x in string_nll]  # discard is_greedy
            loglikelihoods.append(sum(string_nll))
        return loglikelihoods

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Optional[Tuple[str, str]], TokenSequence, TokenSequence]],
        disable_tqdm: Optional[bool] = False,
    ) -> List[Tuple[float, bool]]:
        """Scores batched `(cache_keys, encoder_inputs, decoder_targets)` requests.

        `cache_keys` is `(contexts, continuations)` for the batch, or `None`
        to skip partial caching. Returns `(sum of target log-probs, is_greedy)`
        for every row, in order.
        """
        results = []
        for chunk in tqdm(requests, total=len(requests), disable=disable_tqdm):
            cache_keys, inputs_tokens, targets_tokens = chunk
            inputs_tokens = inputs_tokens.to(self.device)
            targets_tokens = targets_tokens.to(self.device)
            outputs = self._model_call(inputs=inputs_tokens, labels=targets_tokens)
            log_softmaxes = F.log_softmax(outputs.logits, dim=-1)

            if cache_keys is None:
                keys = [None] * targets_tokens["input_ids"].size(0)
            else:
                keys = zip(cache_keys[0], cache_keys[1])
            output_iterator = zip(
                keys,
                log_softmaxes,
                targets_tokens["input_ids"],
                targets_tokens["attention_mask"],
            )
            for cache_key, log_softmax, target_tokens, target_mask in output_iterator:
                # Keeping the first `length` positions assumes right-padded
                # targets (tokenizer's padding side) — TODO confirm per model.
                length = target_mask.sum()
                log_softmax = log_softmax[:length]
                target_tokens = target_tokens[:length]
                greedy_tokens = log_softmax.argmax(dim=-1)
                max_equal = (greedy_tokens == target_tokens).all()
                target_logits = torch.gather(
                    log_softmax, 1, target_tokens.unsqueeze(-1)
                ).squeeze(-1)
                answer = (float(target_logits.sum()), bool(max_equal))
                results.append(answer)
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return results

    def _model_call(
        self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
    ) -> TokenSequence:
        """Runs the encoder-decoder forward pass and returns the model output.

        `labels["input_ids"]` (when given) are passed as the decoder targets.
        """
        label_ids = None if labels is None else labels["input_ids"]
        return self.model(**inputs, labels=label_ids)

    def _model_generate(
        self,
        inputs: transformers.BatchEncoding,
        max_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> TokenSequence:
        """Greedily generates up to `max_tokens` decoder tokens per row and
        returns the raw `generate` output."""
        input_ids = inputs["input_ids"][:, -self.max_length :].to(self.device)
        attention_mask = inputs["attention_mask"][:, -self.max_length :].to(self.device)

        # Assume that there will always only be one token in the decoder inputs
        # (the decoder start token), which holds for existing HF models. If this
        # is ever violated, the prefix length can be measured by generating a
        # single token from a dummy input and subtracting one from its length.
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, 1, input_ids.shape[0]
        )
        generations = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            do_sample=False,
        )
        return generations
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence.

    Generation stops once every row of the batch contains `sequence` in the
    decoded tail of the tokens produced after its first
    `initial_decoder_input_length` tokens.
    """

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ):
        self.initial_decoder_input_length = initial_decoder_input_length
        # One flag per batch row; once set, a row stays done.
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        # Look back a couple of tokens further than the encoded stop sequence:
        # the model may emit the same text split into more tokens than
        # `tokenizer.encode(sequence)` produces (e.g. "\n" + "\n" vs "\n\n").
        # Matching slightly earlier is harmless because responses are cut at
        # the first occurrence of the stop sequence afterwards.
        self.lookback_len = self.sequence_id_len + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Only the generated part of each row matters, and only its last
        # `lookback_len` tokens need decoding to spot the stop sequence.
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
            :, -self.lookback_len :
        ]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return all(self.done_tracker)
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: Optional[List[str]],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Builds one `MultiTokenEOSCriteria` per stop sequence.

    Args:
        tokenizer: Tokenizer used to encode the stop sequences and decode the
            generated tokens.
        stop_sequences: Stop strings; `None` or an empty list yields an empty
            criteria list.
        initial_decoder_input_length: Number of leading tokens per row that are
            input rather than generated output, and so are never searched.
        batch_size: Number of rows in the generation batch.
    """
    return transformers.StoppingCriteriaList(
        [
            MultiTokenEOSCriteria(
                sequence, tokenizer, initial_decoder_input_length, batch_size
            )
            for sequence in (stop_sequences or [])
        ]
    )