-
Notifications
You must be signed in to change notification settings - Fork 62
/
modeling.py
936 lines (792 loc) · 37.5 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NeuronModelForXXX classes for inference on neuron devices using the same API as Transformers."""
import copy
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import torch
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import (
GenerationMixin,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.utils import ModelOutput
from .generation import TokenSelector
from .modeling_base import NeuronBaseModel
from .modeling_decoder import NeuronDecoderModel
if TYPE_CHECKING:
from pathlib import Path
from tempfile import TemporaryDirectory
from transformers import GenerationConfig, PretrainedConfig
from transformers.generation import StoppingCriteriaList
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Processor class name substituted for the `{processor_class}` placeholder of the docstring examples below.
_TOKENIZER_FOR_DOC = "AutoTokenizer"
# Class-level documentation shared by the NeuronBaseModel task classes below; it is passed to
# `add_start_docstrings` on each of them.
NEURON_MODEL_START_DOCSTRING = r"""
This model inherits from [`~neuron.modeling.NeuronBaseModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving).
Args:
config (`transformers.PretrainedConfig`): [PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig) is the Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`optimum.neuron.modeling.NeuronBaseModel.from_pretrained`] method to load the model weights.
model (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module with embedded NEFF(Neuron Executable File Format) compiled by neuron(x) compiler.
"""
# Argument documentation shared by the text-model `forward` methods below. `{0}` is filled in with the
# expected input shape through `str.format`, so the text must not contain any other brace pair.
NEURON_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.Tensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`](https://huggingface.co/docs/transformers/autoclass_tutorial#autotokenizer).
See [`PreTrainedTokenizer.encode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.encode) and
[`PreTrainedTokenizer.__call__`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.__call__) for details.
[What are input IDs?](https://huggingface.co/docs/transformers/glossary#input-ids)
attention_mask (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](https://huggingface.co/docs/transformers/glossary#attention-mask)
token_type_ids (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
- 0 for tokens that are **sentence A**,
- 1 for tokens that are **sentence B**.
[What are token type IDs?](https://huggingface.co/docs/transformers/glossary#token-type-ids)
"""
# Argument documentation for image inputs. `{0}` marks where the expected shape goes
# (presumably filled in with `str.format`, as is done for the text inputs docstring).
NEURON_IMAGE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`):
Pixel values corresponding to the images in the current batch.
Pixel values can be obtained from encoded images using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor).
"""
# Usage example added to the documentation of `NeuronModelForFeatureExtraction.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
FEATURE_EXTRACTION_EXAMPLE = r"""
Example of feature extraction:
*(Following model is compiled with neuronx compiler and can only be run on INF2. Replace "neuronx" with "neuron" if you are using INF1.)*
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Dear Evan Hansen is the winner of six Tony Awards.", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> list(last_hidden_state.shape)
[1, 13, 384]
```
"""
@add_start_docstrings(
    "\nNeuron Model with a BaseModelOutput for feature-extraction tasks.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForFeatureExtraction(NeuronBaseModel):
    """
    Feature-extraction model running on Neuron devices.

    `forward` returns a `BaseModelOutputWithPooling` holding the un-padded last hidden state and,
    when the compiled model produces a second output, the un-padded pooler output.
    """

    auto_model_class = AutoModel

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + FEATURE_EXTRACTION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForFeatureExtraction",
            checkpoint="optimum/all-MiniLM-L6-v2-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # The inputs are splatted positionally into the model below, so insertion order is kept:
        # input_ids, attention_mask, then token_type_ids when given.
        neuron_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            neuron_inputs["token_type_ids"] = token_type_ids

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)
            batch_size, sequence_length = input_ids.shape[0], input_ids.shape[1]

            # last_hidden_state -> (batch_size, sequence_length, hidden_size):
            # strip the padding on batch_size (dim 0) and sequence_length (dim 1).
            last_hidden_state = self.remove_padding(
                [outputs[0]], dims=[0, 1], indices=[batch_size, sequence_length]
            )[0]

            # pooler_output -> (batch_size, hidden_size), only when the model returns it:
            # strip the padding on batch_size (dim 0).
            pooler_output = (
                self.remove_padding([outputs[1]], dims=[0], indices=[batch_size])[0] if len(outputs) > 1 else None
            )

        return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output)
# Usage example added to the documentation of `NeuronModelForSentenceTransformers.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
SENTENCE_TRANSFORMERS_EXAMPLE = r"""
Example of TEXT Sentence Transformers:
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("In the smouldering promise of the fall of Troy, a mythical world of gods and mortals rises from the ashes.", return_tensors="pt")
>>> outputs = model(**inputs)
>>> token_embeddings = outputs.token_embeddings
>>> sentence_embedding = outputs.sentence_embedding
```
"""
@add_start_docstrings(
    "\nNeuron Model for Sentence Transformers.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForSentenceTransformers(NeuronBaseModel):
    """
    Sentence Transformers model on Neuron devices.

    For model types containing "clip", `forward` returns text and image embeddings; otherwise it
    returns token embeddings and a sentence embedding.
    """

    auto_model_class = AutoModel
    library_name = "sentence_transformers"

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + SENTENCE_TRANSFORMERS_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForSentenceTransformers",
            checkpoint="optimum/bge-base-en-v1.5-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # NOTE: `token_type_ids` and `**kwargs` are accepted but not forwarded to the model by this method.
        model_type = self.config.neuron["model_type"]
        is_clip = "clip" in model_type

        # The CLIP branch below un-pads the image embeddings using `pixel_values.shape[0]`, so a missing
        # `pixel_values` can never succeed. Fail early with an explicit message instead of an
        # AttributeError on `None` after the model has already run.
        if is_clip and pixel_values is None:
            raise ValueError(
                f"`pixel_values` must be provided when running a CLIP model (model_type={model_type!r})."
            )

        # The inputs are splatted positionally into the model below.
        # The input order for clip is: input_ids, pixel_values, attention_mask.
        neuron_inputs = {"input_ids": input_ids}
        if pixel_values is not None:
            neuron_inputs["pixel_values"] = pixel_values
        neuron_inputs["attention_mask"] = attention_mask

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)

            if is_clip:
                # Remove padding on batch_size (dim 0) for both embeddings.
                text_embeds = self.remove_padding([outputs[0]], dims=[0], indices=[input_ids.shape[0]])[0]
                image_embeds = self.remove_padding([outputs[1]], dims=[0], indices=[pixel_values.shape[0]])[0]
                return ModelOutput(text_embeds=text_embeds, image_embeds=image_embeds)

            # token_embeddings -> (batch_size, sequence_length, hidden_size):
            # remove padding on batch_size (dim 0) and sequence_length (dim 1).
            token_embeddings = self.remove_padding(
                [outputs[0]], dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )[0]
            # sentence_embedding -> (batch_size, hidden_size): remove padding on batch_size (dim 0).
            sentence_embedding = self.remove_padding([outputs[1]], dims=[0], indices=[input_ids.shape[0]])[0]

        return ModelOutput(token_embeddings=token_embeddings, sentence_embedding=sentence_embedding)
# Usage example added to the documentation of `NeuronModelForMaskedLM.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
MASKED_LM_EXAMPLE = r"""
Example of fill mask:
*(Following model is compiled with neuronx compiler and can only be run on INF2. Replace "neuronx" with "neuron" if you are using INF1.)*
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("This [MASK] Agreement is between General Motors and John Murray.", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 13, 30522]
```
"""
@add_start_docstrings(
    "\nNeuron Model with a MaskedLMOutput for masked language modeling tasks.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForMaskedLM(NeuronBaseModel):
    """
    Masked language model on Neuron devices.

    `forward` returns a `MaskedLMOutput` whose `logits` have the input padding removed.
    """

    auto_model_class = AutoModelForMaskedLM

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + MASKED_LM_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForMaskedLM",
            checkpoint="optimum/legal-bert-base-uncased-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # The inputs are splatted positionally into the model below, so insertion order is kept.
        neuron_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            neuron_inputs["token_type_ids"] = token_type_ids

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            # Model output shape: (batch_size, sequence_length, vocab_size).
            padded_outputs = self.model(*inputs)
            # Remove padding on batch_size (dim 0) and sequence_length (dim 1).
            unpadded_outputs = self.remove_padding(
                padded_outputs, dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )

        return MaskedLMOutput(logits=unpadded_outputs[0])
# Usage example added to the documentation of `NeuronModelForQuestionAnswering.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
# The example only passes the tokenizer outputs: `forward` takes any other keyword through `**kwargs`
# and does not use it, so label-style arguments such as `start_positions` would have no effect.
QUESTION_ANSWERING_EXAMPLE = r"""
Example of question answering:
*(Following model is compiled with neuronx compiler and can only be run on INF2.)*
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> question, text = "Are there wheelchair spaces in the theatres?", "Yes, we have reserved wheelchair spaces with a good view."
>>> inputs = tokenizer(question, text, return_tensors="pt")
>>> outputs = model(**inputs)
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
```
"""
@add_start_docstrings(
    "\nNeuron Model with a QuestionAnsweringModelOutput for extractive question-answering tasks like SQuAD.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForQuestionAnswering(NeuronBaseModel):
    """
    Extractive question-answering model on Neuron devices.

    `forward` returns a `QuestionAnsweringModelOutput` with un-padded start and end logits.
    """

    auto_model_class = AutoModelForQuestionAnswering

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + QUESTION_ANSWERING_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForQuestionAnswering",
            checkpoint="optimum/roberta-base-squad2-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # The inputs are splatted positionally into the model below, so insertion order is kept.
        neuron_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            neuron_inputs["token_type_ids"] = token_type_ids

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            # Each model output has shape [batch_size, sequence_length].
            padded_outputs = self.model(*inputs)
            # Remove padding on batch_size (dim 0) and sequence_length (dim 1).
            unpadded_outputs = self.remove_padding(
                padded_outputs, dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )

        # The first output holds the start logits, the second the end logits.
        return QuestionAnsweringModelOutput(start_logits=unpadded_outputs[0], end_logits=unpadded_outputs[1])
# Usage example added to the documentation of `NeuronModelForSequenceClassification.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
SEQUENCE_CLASSIFICATION_EXAMPLE = r"""
Example of single-label classification:
*(Following model is compiled with neuronx compiler and can only be run on INF2.)*
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hamilton is considered to be the best musical of human history.", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
```
"""
@add_start_docstrings(
    "\nNeuron Model with a sequence classification/regression head on top (a linear layer on top of the\n"
    "pooled output) e.g. for GLUE tasks.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForSequenceClassification(NeuronBaseModel):
    """
    Sequence classification model on Neuron devices.

    `forward` returns a `SequenceClassifierOutput` whose `logits` have the batch padding removed.
    """

    auto_model_class = AutoModelForSequenceClassification

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + SEQUENCE_CLASSIFICATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForSequenceClassification",
            checkpoint="optimum/distilbert-base-uncased-finetuned-sst-2-english-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # The inputs are splatted positionally into the model below, so insertion order is kept.
        neuron_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            neuron_inputs["token_type_ids"] = token_type_ids

        with self.neuron_padding_manager(neuron_inputs) as inputs:
            # Model output shape: [batch_size, num_labels].
            padded_outputs = self.model(*inputs)
            # Remove padding on batch_size (dim 0) only.
            unpadded_outputs = self.remove_padding(padded_outputs, dims=[0], indices=[input_ids.shape[0]])

        return SequenceClassifierOutput(logits=unpadded_outputs[0])
# Usage example added to the documentation of `NeuronModelForTokenClassification.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
TOKEN_CLASSIFICATION_EXAMPLE = r"""
Example of token classification:
*(Following model is compiled with neuronx compiler and can only be run on INF2.)*
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Lin-Manuel Miranda is an American songwriter, actor, singer, filmmaker, and playwright.", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 20, 9]
```
"""
@add_start_docstrings(
    "\nNeuron Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.\n"
    "for Named-Entity-Recognition (NER) tasks.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForTokenClassification(NeuronBaseModel):
    """
    Token classification model on Neuron devices.

    `forward` returns a `TokenClassifierOutput` whose `logits` have the input padding removed.
    """

    auto_model_class = AutoModelForTokenClassification

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + TOKEN_CLASSIFICATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForTokenClassification",
            checkpoint="optimum/bert-base-NER-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # The inputs are splatted positionally into the model below, so insertion order is kept.
        neuron_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            neuron_inputs["token_type_ids"] = token_type_ids

        # Run inference on padded inputs, then strip the padding from the outputs.
        with self.neuron_padding_manager(neuron_inputs) as inputs:
            # Model output shape: [batch_size, sequence_length, num_labels].
            padded_outputs = self.model(*inputs)
            # Remove padding on batch_size (dim 0) and sequence_length (dim 1).
            unpadded_outputs = self.remove_padding(
                padded_outputs, dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )

        return TokenClassifierOutput(logits=unpadded_outputs[0])
# Usage example added to the documentation of `NeuronModelForMultipleChoice.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
MULTIPLE_CHOICE_EXAMPLE = r"""
Example of multiple choice:
*(Following model is compiled with neuronx compiler and can only be run on INF2.)*
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> num_choices = 4
>>> first_sentence = ["Members of the procession walk down the street holding small horn brass instruments."] * num_choices
>>> second_sentence = [
... "A drum line passes by walking down the street playing their instruments.",
... "A drum line has heard approaching them.",
... "A drum line arrives and they're outside dancing and asleep.",
... "A drum line turns the lead singer watches the performance."
... ]
>>> inputs = tokenizer(first_sentence, second_sentence, truncation=True, padding=True)
>>> # Unflatten the inputs values expanding it to the shape [batch_size, num_choices, seq_length]
>>> for k, v in inputs.items():
...     inputs[k] = [v[i: i + num_choices] for i in range(0, len(v), num_choices)]
>>> inputs = dict(inputs.convert_to_tensors(tensor_type="pt"))
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 4]
```
"""
@add_start_docstrings(
    "\nNeuron Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a\n"
    "softmax) e.g. for RocStories/SWAG tasks.\n",
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForMultipleChoice(NeuronBaseModel):
    """
    Multiple-choice model on Neuron devices.

    `forward` returns a `MultipleChoiceModelOutput` whose `logits` have the input padding removed.
    """

    auto_model_class = AutoModelForMultipleChoice

    @add_start_docstrings_to_model_forward(
        NEURON_TEXT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
        + MULTIPLE_CHOICE_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForMultipleChoice",
            checkpoint="optimum/bert-base-uncased_SWAG-neuronx",
        )
    )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # The inputs are splatted positionally into the model below, so insertion order is kept.
        neuron_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            neuron_inputs["token_type_ids"] = token_type_ids

        # Run inference on padded inputs, then strip the padding from the outputs.
        with self.neuron_padding_manager(neuron_inputs) as inputs:
            # Model output shape: [batch_size, num_choices].
            padded_outputs = self.model(*inputs)
            # Remove padding on batch_size (dim 0) and num_choices (last dim); the inputs are documented
            # as (batch_size, num_choices, sequence_length), so `input_ids.shape[1]` is num_choices.
            unpadded_outputs = self.remove_padding(
                padded_outputs, dims=[0, -1], indices=[input_ids.shape[0], input_ids.shape[1]]
            )

        return MultipleChoiceModelOutput(logits=unpadded_outputs[0])
# Class-level documentation passed to `add_start_docstrings` on `NeuronModelForCausalLM`.
# NOTE(review): the Args below list `model` and `model_path`, while `NeuronModelForCausalLM.__init__`
# takes `config`, `checkpoint_dir`, `compiled_dir` and `generation_config` — confirm which is current.
NEURON_CAUSALLM_MODEL_START_DOCSTRING = r"""
This model inherits from [`~neuron.modeling.NeuronDecoderModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving).
Args:
model (`torch.nn.Module`): [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) is the neuron decoder graph.
config (`transformers.PretrainedConfig`): [PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig) is the Model configuration class with all the parameters of the model.
model_path (`Path`): The directory where the compiled artifacts for the model are stored.
It can be a temporary directory if the model has never been saved locally before.
generation_config (`transformers.GenerationConfig`): [GenerationConfig](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) holds the configuration for the model generation task.
"""
# Argument documentation added to `NeuronModelForCausalLM.forward`.
NEURON_CAUSALLM_MODEL_FORWARD_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, sequence_length)`.
cache_ids (`torch.LongTensor`): The indices at which the cached key and value for the current inputs need to be stored.
start_ids (`torch.LongTensor`): The indices of the first tokens to be processed, deduced from the attention masks.
"""
# Usage example added to the documentation of `NeuronModelForCausalLM.forward`.
# `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with `str.format`.
TEXT_GENERATION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> inputs = tokenizer("My favorite moment of the day is", return_tensors="pt")
>>> gen_tokens = model.generate(**inputs, do_sample=True, temperature=0.9, min_length=20, max_length=20)
>>> tokenizer.batch_decode(gen_tokens) # doctest: +IGNORE_RESULT
```
"""
@add_start_docstrings(
r"""
Neuron model with a causal language modeling head for inference on Neuron devices.
""",
NEURON_CAUSALLM_MODEL_START_DOCSTRING,
)
class NeuronModelForCausalLM(NeuronDecoderModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
main_input_name = "input_ids"
def __init__(
    self,
    config: "PretrainedConfig",
    checkpoint_dir: Union[str, "Path", "TemporaryDirectory"],
    compiled_dir: Optional[Union[str, "Path", "TemporaryDirectory"]] = None,
    generation_config: Optional["GenerationConfig"] = None,
):
    """Initialize the parent model, then cache the settings used during generation.

    All four arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(config, checkpoint_dir, compiled_dir=compiled_dir, generation_config=generation_config)

    # Static dimensions recorded in the `neuron` section of the model configuration.
    neuron_settings = self.config.neuron
    self.batch_size = neuron_settings["batch_size"]
    self.max_length = neuron_settings["sequence_length"]

    # Falsy when the model has no neuron config; otherwise the config's `continuous_batching` value.
    model_neuron_config = self.model.neuron_config
    self.continuous_batching = model_neuron_config and model_neuron_config.continuous_batching

    # The generate method from GenerationMixin expects the device attribute to be set.
    self.device = torch.device("cpu")
def reset_generation(self):
    """No-op: the body performs no work and returns `None`."""
    return None
@add_start_docstrings_to_model_forward(
    NEURON_CAUSALLM_MODEL_FORWARD_DOCSTRING
    + TEXT_GENERATION_EXAMPLE.format(
        processor_class="AutoTokenizer",
        model_class="NeuronModelForCausalLM",
        checkpoint="gpt2",
    )
)
def forward(
    self,
    input_ids: torch.Tensor,
    cache_ids: torch.Tensor,
    start_ids: torch.Tensor = None,
    return_dict: bool = True,
):
    # Compute the logits; the underlying model stores the current keys and values at the
    # positions given by `cache_ids`.
    logits = self.model.forward(input_ids, cache_ids, start_ids)
    # Insert a singleton dimension at position 1.
    logits = logits[:, None, :]

    # The cache is static, so past keys and values are not part of the result.
    if not return_dict:
        return (logits,)
    return ModelOutput([("logits", logits)])
def get_start_ids(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    seq_ids: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Return the `start_ids` to pass to the model for the given inputs.

    The start_ids parameter has different meanings:
    - for continuous (unpadded) batching it corresponds to the sequence id,
    - for static batching it corresponds to the start of the padded sequence.

    Args:
        input_ids: The input tokens; only the size of the first dimension is read.
        attention_mask: The attention mask, or `None`. Only used for static batching.
        seq_ids: Optional sequence ids, only used for continuous batching. The size of its first
            dimension must match that of `input_ids`.

    Returns:
        For continuous batching, `seq_ids` (or `torch.arange(batch_size)` when it is `None`).
        For static batching, the per-row index of the maximum of `attention_mask` along dim 1,
        or `None` when no attention mask is given.
    """
    if self.continuous_batching:
        if seq_ids is None:
            # Default to one sequence id per input row: 0, 1, ..., batch_size - 1.
            return torch.arange(input_ids.shape[0])
        assert (
            seq_ids.shape[0] == input_ids.shape[0]
        ), "seq_ids and input_ids must have the same first dimension"
        return seq_ids

    if attention_mask is None:
        return None
    # `max` along a dimension returns (values, indices). For a 0/1 mask the index is the
    # position of the first 1 in each row, i.e. where the non-padded tokens start.
    return attention_mask.max(axis=1).indices
def get_cache_ids(self, attention_mask: torch.Tensor, prefill: bool) -> Optional[torch.Tensor]:
    """Return the cache positions for the inputs described by `attention_mask`.

    Args:
        attention_mask: A 2-D mask with one row per sequence.
        prefill: `True` when the whole prompt is processed, `False` when decoding.

    Returns:
        Continuous batching, prefill: a tensor shaped and typed like `attention_mask` where each
        row holds `0..n-1` in its first `n` positions (`n` being the row sum of the mask) and
        zeros elsewhere.
        Continuous batching, decode: a `(rows, 1)` tensor holding `n - 1` for each row.
        Static batching, prefill: `None`.
        Static batching, decode: a one-element `int32` tensor holding the mask width minus one.
    """
    cache_len = attention_mask.shape[1]
    if not self.continuous_batching:
        # Static batching
        return None if prefill else torch.tensor([cache_len - 1], dtype=torch.int32)

    # Evaluate the inputs that are not masked for each sequence
    input_length = attention_mask.sum(axis=1)
    if not prefill:
        # When decoding, cache_ids contains a single value per sequence
        return (input_length - 1).unsqueeze(1)

    # When prefilling, cache_ids is an increasing range over the first `input_length` positions of
    # each row. Broadcasting (1, cache_len) positions against (rows, 1) lengths builds every row at
    # once instead of looping over the batch in Python.
    positions = torch.arange(cache_len, device=attention_mask.device).unsqueeze(0)
    in_prompt = positions < input_length.unsqueeze(1)
    return (positions * in_prompt).to(attention_mask.dtype)
def prepare_inputs_for_prefill(
    self, input_ids: torch.Tensor, attention_mask: torch.Tensor, seq_ids: Optional[torch.Tensor] = None
) -> Dict[str, Optional[torch.Tensor]]:
    """Build the model inputs for the prefill step.

    Args:
        input_ids: The prompt tokens, of shape `(batch_size, sequence_length)`.
        attention_mask: The mask matching `input_ids`.
        seq_ids: Optional sequence ids, passed through to `get_start_ids`.

    Returns:
        A dict with the keys `input_ids`, `cache_ids` and `start_ids`. `cache_ids` and `start_ids`
        are whatever `get_cache_ids` / `get_start_ids` return and may be `None`.
    """
    start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
    cache_ids = self.get_cache_ids(attention_mask, prefill=True)

    if self.continuous_batching and torch.any(attention_mask[:, 0] == 0):
        # Inputs are left padded: we need to invert padding as continuous batching requires right-padding.
        # A new tensor is built so that the caller's `input_ids` is left untouched.
        batch_size, seq_len = input_ids.shape
        input_length = attention_mask.sum(axis=1)
        right_padded_ids = torch.zeros_like(input_ids)
        for row in range(batch_size):
            row_length = input_length[row]
            # Move the last `row_length` tokens of the row to its beginning.
            right_padded_ids[row, :row_length] = input_ids[row, seq_len - row_length :]
        input_ids = right_padded_ids

    return {
        "input_ids": input_ids,
        "cache_ids": cache_ids,
        "start_ids": start_ids,
    }
def prepare_inputs_for_decode(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    seq_ids: Optional[torch.Tensor] = None,
) -> Dict[str, Optional[torch.Tensor]]:
    """Build the model inputs for a decoding step.

    Args:
        input_ids: The tokens so far, of shape `(batch_size, sequence_length)`.
        attention_mask: The mask matching `input_ids`.
        seq_ids: Optional sequence ids, passed through to `get_start_ids`.

    Returns:
        A dict with the keys `input_ids` (last column of the inputs only), `cache_ids` and
        `start_ids`. `start_ids` is whatever `get_start_ids` returns and may be `None`.
    """
    start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
    cache_ids = self.get_cache_ids(attention_mask, prefill=False)
    return {
        # Only pass the last tokens of each sample; `-1:` keeps the sequence dimension.
        "input_ids": input_ids[:, -1:],
        "cache_ids": cache_ids,
        "start_ids": start_ids,
    }
def can_generate(self) -> bool:
    """Always report `True`, satisfying the check made in `GenerationMixin.generate()`."""
    supports_generation = True
    return supports_generation
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    generation_config: Optional["GenerationConfig"] = None,
    stopping_criteria: Optional["StoppingCriteriaList"] = None,
    **kwargs,
) -> torch.LongTensor:
    r"""
    A streamlined generate() method overriding the transformers.GenerationMixin.generate() method.

    This method uses the same logits processors/warpers and stopping criterias as the transformers library
    `generate()` method but restricts the generation to greedy search and sampling.
    It does not support transformers `generate()` advanced options.

    Please refer to https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
    for details on generation configuration.

    Parameters:
        input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. When omitted, every input
            position is attended to.
        generation_config (`~transformers.generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, default will be used, which had the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~transformers.generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
            Custom stopping criteria that complement the default stopping criteria built from arguments and a
            generation config.

    Returns:
        `torch.LongTensor`: the token sequences (prompt followed by the generated tokens), one row per
        input sequence. Rows added to reach the model static batch size are not returned.

    Raises:
        ValueError: if the input sequence length or batch size exceeds the model static dimensions, or if
            the batch must be padded and neither a pad token id nor an eos token id is configured.
    """
    # The actual generation configuration is a combination of config and parameters
    generation_config = copy.deepcopy(self.generation_config if generation_config is None else generation_config)
    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
    # Check model kwargs are actually used by either prepare_inputs_for_generation or forward
    self._validate_model_kwargs(model_kwargs)

    # Instantiate a TokenSelector for the specified configuration
    selector = TokenSelector.create(
        input_ids, generation_config, self, self.max_length, stopping_criteria=stopping_criteria
    )

    # Verify that the inputs are compatible with the model static input dimensions
    batch_size, sequence_length = input_ids.shape
    if sequence_length > self.max_length:
        raise ValueError(
            f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
        )
    padded_input_ids = input_ids
    padded_attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
    if batch_size > self.batch_size:
        raise ValueError(
            f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
        )
    elif batch_size < self.batch_size and not self.continuous_batching:
        logger.warning("Inputs will be padded to match the model static batch size. This will increase latency.")
        padding_shape = [self.batch_size - batch_size, sequence_length]
        pad_token_id = generation_config.pad_token_id
        if pad_token_id is None:
            # Fall back to the (first) end-of-sequence token of the model configuration
            eos_token_id = self.config.eos_token_id
            pad_token_id = eos_token_id[0] if isinstance(eos_token_id, list) else eos_token_id
        if pad_token_id is None:
            raise ValueError(
                "Inputs must be padded to the model static batch size, but neither a pad_token_id nor an eos_token_id is set."
            )
        # Create the padding rows from the tensors they extend, so that they share their dtype and device
        padding = padded_input_ids.new_full(padding_shape, pad_token_id)
        padded_input_ids = torch.cat([padded_input_ids, padding])
        # Padding rows are fully masked
        padding = padded_attention_mask.new_zeros(padding_shape)
        padded_attention_mask = torch.cat([padded_attention_mask, padding])

    output_ids = self.generate_tokens(
        padded_input_ids,
        selector,
        batch_size,
        padded_attention_mask,
        **model_kwargs,
    )
    # Drop the rows that were only added to reach the static batch size
    return output_ids[:batch_size, :]
def generate_tokens(
    self,
    input_ids: torch.LongTensor,
    selector: TokenSelector,
    batch_size: int,
    attention_mask: torch.Tensor,
    **model_kwargs,
) -> torch.LongTensor:
    r"""
    Generate tokens using sampling or greedy search.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        selector (`TokenSelector`):
            The object implementing the generation logic based on transformers processors and stopping criterias.
        batch_size (`int`):
            The actual input batch size. Used to avoid generating tokens for padded inputs.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices.
        model_kwargs:
            Additional model specific kwargs. NOTE(review): they are accepted but not forwarded to the
            model calls made in this method.

    Return:
        `torch.LongTensor`: A `torch.LongTensor` containing the prompt followed by the generated tokens.
    """
    # keep track of which sequences are already finished: rows beyond the actual
    # batch size are padding and are considered finished from the start
    unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    unfinished_sequences[:batch_size] = 1

    # Prefill and obtain the first token
    model_inputs = self.prepare_inputs_for_prefill(input_ids, attention_mask)
    outputs = self(
        **model_inputs,
        return_dict=True,
    )

    # The end-of-sequence ids do not change during generation: build the tensor once
    eos_token_ids = torch.tensor(selector.eos_token_ids)

    # auto-regressive generation
    while True:
        next_token_logits = outputs.logits[:, -1, :]

        next_tokens = selector.select(input_ids, next_token_logits)

        # finished sentences should have their next token be a padding token
        next_tokens = next_tokens * unfinished_sequences + selector.pad_token_id * (1 - unfinished_sequences)

        # update inputs for the next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        # if eos_token was found in one sentence, set sentence to finished
        unfinished_sequences = unfinished_sequences * torch.isin(next_tokens, eos_token_ids, invert=True)

        # stop when each sentence is finished
        if unfinished_sequences.max() == 0:
            break

        # stop if we exceed the maximum length
        if selector.stopping_criteria(input_ids, None):
            break

        # forward pass to get next token
        model_inputs = self.prepare_inputs_for_decode(input_ids, attention_mask)
        outputs = self(
            **model_inputs,
            return_dict=True,
        )

    return input_ids