-
Notifications
You must be signed in to change notification settings - Fork 62
/
model_configs.py
408 lines (306 loc) · 12.7 KB
/
model_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model specific Neuron configurations."""
from typing import TYPE_CHECKING, Dict, List
import torch
from ...utils import (
DummySeq2SeqDecoderTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
NormalizedConfig,
NormalizedConfigManager,
NormalizedTextAndVisionConfig,
is_diffusers_available,
)
from ..tasks import TasksManager
from .config import (
TextAndVisionNeuronConfig,
TextEncoderNeuronConfig,
TextNeuronDecoderConfig,
VisionNeuronConfig,
)
if TYPE_CHECKING:
if is_diffusers_available():
from diffusers.models.vae import Decoder as VaeDecoder
# Decorator factory that registers the Neuron export configs below with
# `TasksManager` under the "neuron" backend.
register_in_tasks_manager = TasksManager.create_register("neuron")

# Task names shared by every encoder-only text model registered in this module.
COMMON_TEXT_TASKS = [
    "feature-extraction",
    "fill-mask",
    "multiple-choice",
    "question-answering",
    "text-classification",
    "token-classification",
]
@register_in_tasks_manager("bert", *COMMON_TEXT_TASKS)
class BertNeuronConfig(TextEncoderNeuronConfig):
    """Neuron export configuration for BERT, used as the base of most text encoders here.

    `inputs` builds a new list on every access, so subclasses may mutate the
    returned list without affecting other callers.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        ids, mask, segments = "input_ids", "attention_mask", "token_type_ids"
        return [ids, mask, segments]
@register_in_tasks_manager("albert", *COMMON_TEXT_TASKS)
class AlbertNeuronConfig(BertNeuronConfig):
    """Neuron export configuration for ALBERT; identical to the BERT configuration."""
@register_in_tasks_manager("convbert", *COMMON_TEXT_TASKS)
class ConvBertNeuronConfig(BertNeuronConfig):
    """Neuron export configuration for ConvBERT.

    Inputs match BERT. For `feature-extraction` only `last_hidden_state` is
    exported; other tasks use the common per-task outputs.
    """

    @property
    def outputs(self) -> List[str]:
        """Output names for the configured task."""
        return (
            ["last_hidden_state"]
            if self.task == "feature-extraction"
            else self._TASK_TO_COMMON_OUTPUTS[self.task]
        )
@register_in_tasks_manager("electra", *COMMON_TEXT_TASKS)
class ElectraNeuronConfig(ConvBertNeuronConfig):
    """Neuron export configuration for ELECTRA; reuses the ConvBERT inputs and outputs."""
@register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS)
class FlaubertNeuronConfig(ConvBertNeuronConfig):
    """Neuron export configuration for FlauBERT; reuses the ConvBERT inputs and outputs."""
@register_in_tasks_manager("mobilebert", *COMMON_TEXT_TASKS)
class MobileBertNeuronConfig(BertNeuronConfig):
    """Neuron export configuration for MobileBERT; identical to the BERT configuration."""
@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ConvBertNeuronConfig):
    """Neuron export configuration for RoFormer; reuses the ConvBERT inputs and outputs."""
@register_in_tasks_manager("xlm", *COMMON_TEXT_TASKS)
class XLMNeuronConfig(ConvBertNeuronConfig):
    """Neuron export configuration for XLM; reuses the ConvBERT inputs and outputs."""
@register_in_tasks_manager("distilbert", *COMMON_TEXT_TASKS)
class DistilBertNeuronConfig(BertNeuronConfig):
    """Neuron export configuration for DistilBERT.

    Only `input_ids` and `attention_mask` are fed (no `token_type_ids`).
    `feature-extraction` exports only `last_hidden_state`.
    """

    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        return [
            "input_ids",
            "attention_mask",
        ]

    @property
    def outputs(self) -> List[str]:
        """Output names for the configured task."""
        return (
            ["last_hidden_state"]
            if self.task == "feature-extraction"
            else self._TASK_TO_COMMON_OUTPUTS[self.task]
        )
@register_in_tasks_manager("camembert", *COMMON_TEXT_TASKS)
class CamembertNeuronConfig(BertNeuronConfig):
    """Neuron export configuration for CamemBERT.

    Same as BERT except that `token_type_ids` is not among the inputs.
    """

    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        return [
            "input_ids",
            "attention_mask",
        ]
@register_in_tasks_manager("mpnet", *COMMON_TEXT_TASKS)
class MPNetNeuronConfig(CamembertNeuronConfig):
    """Neuron export configuration for MPNet; inputs are `input_ids` and `attention_mask`."""
@register_in_tasks_manager("roberta", *COMMON_TEXT_TASKS)
class RobertaNeuronConfig(CamembertNeuronConfig):
    """Neuron export configuration for RoBERTa; inputs are `input_ids` and `attention_mask`."""
@register_in_tasks_manager("xlm-roberta", *COMMON_TEXT_TASKS)
class XLMRobertaNeuronConfig(CamembertNeuronConfig):
    """Neuron export configuration for XLM-RoBERTa; inputs are `input_ids` and `attention_mask`."""
# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta", *COMMON_TEXT_TASKS)
class DebertaNeuronConfig(BertNeuronConfig):
    """Neuron export configuration for DeBERTa.

    Uses the BERT inputs, minus `token_type_ids` when the model config has
    `type_vocab_size == 0`.
    """

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        input_names = super().inputs
        if self._config.type_vocab_size != 0:
            return input_names
        # No token-type vocabulary: drop the trailing `token_type_ids`
        # entry inherited from the BERT input list.
        return input_names[:-1]
# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta-v2", *COMMON_TEXT_TASKS)
class DebertaV2NeuronConfig(DebertaNeuronConfig):
    """Neuron export configuration for DeBERTa-v2; identical to the DeBERTa configuration."""
class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
    """Normalized config for CLIP, naming the attributes that hold its text and vision sub-configs."""

    TEXT_CONFIG = "text_config"
    VISION_CONFIG = "vision_config"
@register_in_tasks_manager("clip", *["feature-extraction", "zero-shot-image-classification"])
class CLIPNeuronConfig(TextAndVisionNeuronConfig):
    """Neuron export configuration for the full CLIP (text + vision) model."""

    NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        return [
            "input_ids",
            "pixel_values",
            "attention_mask",
        ]

    @property
    def outputs(self) -> List[str]:
        """Names of the exported outputs: similarity logits and embeddings."""
        return [
            "logits_per_image",
            "logits_per_text",
            "text_embeds",
            "image_embeds",
        ]
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"])
class CLIPTextWithProjectionNeuronConfig(TextEncoderNeuronConfig):
    """Neuron export configuration for the CLIP text encoder with its projection head.

    Takes `input_ids` only. It exports `text_embeds` and `last_hidden_state`,
    plus `hidden_states` when the config's `output_hidden_states` is truthy.
    """

    MODEL_TYPE = "clip-text-model"
    ATOL_FOR_VALIDATION = 1e-3
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        vocab_size="vocab_size",
        sequence_length="max_position_embeddings",
        num_layers="num_hidden_layers",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        return ["input_ids"]

    @property
    def outputs(self) -> List[str]:
        """Output names; `hidden_states` is appended only when requested by the config."""
        extra = ["hidden_states"] if self._normalized_config.output_hidden_states else []
        return ["text_embeds", "last_hidden_state"] + extra
@register_in_tasks_manager("clip-text-model", *["feature-extraction"])
class CLIPTextNeuronConfig(CLIPTextWithProjectionNeuronConfig):
    """Neuron export configuration for the bare CLIP text model (no projection head).

    Exports `last_hidden_state` and `pooler_output`, plus `hidden_states` when
    the config's `output_hidden_states` is truthy.
    """

    MODEL_TYPE = "clip-text-model"

    @property
    def outputs(self) -> List[str]:
        """Output names; `hidden_states` is appended only when requested by the config."""
        extra = ["hidden_states"] if self._normalized_config.output_hidden_states else []
        return ["last_hidden_state", "pooler_output"] + extra
@register_in_tasks_manager("unet", *["semantic-segmentation"])
class UNetNeuronConfig(VisionNeuronConfig):
    """Neuron export configuration for a UNet.

    Inputs are `sample`, `timestep` and `encoder_hidden_states`. They also
    include `text_embeds` and `time_ids` when the config's
    `addition_embed_type` is "text_time", and `timestep_cond` when
    `time_cond_proj_dim` is set. The single exported output is `sample`.
    """

    ATOL_FOR_VALIDATION = 1e-3
    MANDATORY_AXES = ("batch_size", "sequence_length", "num_channels", "width", "height")
    MODEL_TYPE = "unet"
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the UNet inputs for this config."""
        common_inputs = ["sample", "timestep", "encoder_hidden_states"]
        # TODO : add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs.append("text_embeds")
            common_inputs.append("time_ids")
        if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
            common_inputs.append("timestep_cond")
        return common_inputs

    @property
    def outputs(self) -> List[str]:
        """Names of the exported outputs."""
        return ["sample"]

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        """Build dummy UNet inputs for tracing at the configured static shape.

        Args:
            return_tuple: when exactly `True`, return the input values as a tuple
                (in dict insertion order) instead of the dict.
            **kwargs: forwarded to the parent `generate_dummy_inputs`.

        Returns:
            A dict of dummy inputs, or a tuple of its values if `return_tuple` is True.

        Raises:
            ValueError: if `self.height` and `self.width` differ.
        """
        # For neuron, we use static shape for compiling the unet. Unlike `optimum`, we use the
        # given `height` and `width` instead of the `sample_size`.
        # TODO: Modify optimum.utils.DummyVisionInputGenerator to enable unequal height and width
        # (it prioritizes `image_size` over custom h/w for now).
        if self.height != self.width:
            raise ValueError(
                f"You need to input the same value for `self.height({self.height})` and `self.width({self.width})`."
            )
        self._normalized_config.image_size = self.height

        dummy_inputs = super().generate_dummy_inputs(**kwargs)
        dummy_inputs["timestep"] = dummy_inputs["timestep"].float()
        # NOTE(review): the generated `encoder_hidden_states` entry is indexed with [0];
        # presumably the text generator returns a sequence whose first item is the tensor.
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            # Regroup `text_embeds` and `time_ids` into a single `added_cond_kwargs` entry.
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }

        if return_tuple is True:
            return tuple(dummy_inputs.values())
        return dummy_inputs

    class ModelWrapper(torch.nn.Module):
        """Adapter that lets the UNet be called with positional tensors.

        Positional inputs are matched to `input_names` by position. `text_embeds`
        and `time_ids` are regrouped into `added_cond_kwargs` (None when absent).
        `timestep` is cast to float and expanded to the batch size of `sample`.
        """

        def __init__(self, model, input_names: List[str]):
            super().__init__()
            self.model = model
            self.input_names = input_names

        def forward(self, *inputs):
            """Call the wrapped model with `return_dict=False` and return its output tuple.

            Raises:
                ValueError: if the number of positional inputs differs from `input_names`.
            """
            if len(inputs) != len(self.input_names):
                raise ValueError(
                    f"The model needs {len(self.input_names)} inputs: {self.input_names}."
                    f" But {len(inputs)} inputs are passed."
                )

            ordered_inputs = dict(zip(self.input_names, inputs))
            added_cond_kwargs = {
                "text_embeds": ordered_inputs.pop("text_embeds", None),
                "time_ids": ordered_inputs.pop("time_ids", None),
            }
            sample = ordered_inputs.pop("sample", None)
            # Broadcast the timestep to one value per sample in the batch.
            timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],))
            out_tuple = self.model(
                sample=sample,
                timestep=timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
                **ordered_inputs,
            )
            return out_tuple

    def check_model_inputs_order(self, model, dummy_inputs):
        """Wrap `model` so its positional inputs follow the key order of `dummy_inputs`."""
        # NOTE(review): `generate_dummy_inputs` replaces `text_embeds`/`time_ids` with an
        # `added_cond_kwargs` key; if such a dict is passed here, `ModelWrapper.forward`
        # would forward `added_cond_kwargs` twice. Confirm callers pass the flat form.
        return self.ModelWrapper(model, list(dummy_inputs.keys()))

    @property
    def is_sdxl(self) -> bool:
        """Whether this UNet config has been flagged as SDXL."""
        # Nothing in this class initializes `_is_sdxl` except the setter below, so
        # reading it before it has been set raises AttributeError.
        return self._is_sdxl

    @is_sdxl.setter
    def is_sdxl(self, is_sdxl: bool):
        self._is_sdxl = is_sdxl
@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"])
class VaeEncoderNeuronConfig(VisionNeuronConfig):
    """Neuron export configuration for a VAE encoder (`vae-encoder` model type).

    Takes `sample` as input and exports `latent_sample`.
    """

    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "vae-encoder"
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="in_channels",
        image_size="sample_size",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        return ["sample"]

    @property
    def outputs(self) -> List[str]:
        """Names of the exported outputs."""
        return ["latent_sample"]

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        """Build dummy VAE-encoder inputs for tracing at the configured static shape.

        Args:
            return_tuple: when exactly `True`, return the input values as a tuple
                (in dict insertion order) instead of the dict.
            **kwargs: forwarded to the parent `generate_dummy_inputs`.

        Raises:
            ValueError: if `self.height` and `self.width` differ.
        """
        # For neuron, we use static shape for compiling the VAE encoder. Unlike `optimum`, we use
        # the given `height` and `width` instead of the `sample_size`.
        # TODO: Modify optimum.utils.DummyVisionInputGenerator to enable unequal height and width
        # (it prioritizes `image_size` over custom h/w for now).
        if self.height != self.width:
            raise ValueError(
                f"You need to input the same value for `self.height({self.height})` and `self.width({self.width})`."
            )
        self._normalized_config.image_size = self.height

        dummy_inputs = super().generate_dummy_inputs(**kwargs)
        if return_tuple is True:
            return tuple(dummy_inputs.values())
        return dummy_inputs
@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"])
class VaeDecoderNeuronConfig(VisionNeuronConfig):
    """Neuron export configuration for the diffusers VAE decoder.

    Maps `latent_sample` to `sample`.
    """

    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "vae-decoder"
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="latent_channels",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        """Ordered names of the model inputs."""
        names = ["latent_sample"]
        return names

    @property
    def outputs(self) -> List[str]:
        """Names of the exported outputs."""
        names = ["sample"]
        return names

    def check_model_inputs_order(
        self,
        model: "VaeDecoder",
        dummy_inputs: Dict[str, torch.Tensor],
        **kwargs,
    ):
        """Delegate to the parent implementation, always with `forward_with_tuple=True`.

        Extra keyword arguments are accepted for signature compatibility but are
        not forwarded.
        """
        parent = super()
        return parent.check_model_inputs_order(
            model=model,
            dummy_inputs=dummy_inputs,
            forward_with_tuple=True,
        )
@register_in_tasks_manager("gpt2", "text-generation")
class GPT2NeuronConfig(TextNeuronDecoderConfig):
    """Neuron `text-generation` configuration for GPT-2."""

    # Dotted path of the sampling model class; presumably resolved by the base
    # decoder config against the neuronx backend package — confirm.
    NEURONX_CLASS = "gpt2.model.GPT2ForSampling"
@register_in_tasks_manager("llama", "text-generation")
class LLamaNeuronConfig(TextNeuronDecoderConfig):
    """Neuron `text-generation` configuration for LLaMA."""

    # Dotted path of the sampling model class; presumably resolved by the base
    # decoder config against the neuronx backend package — confirm.
    NEURONX_CLASS = "llama.model.LlamaForSampling"
@register_in_tasks_manager("opt", "text-generation")
class OPTNeuronConfig(TextNeuronDecoderConfig):
    """Neuron `text-generation` configuration for OPT."""

    # Dotted path of the sampling model class; presumably resolved by the base
    # decoder config against the neuronx backend package — confirm.
    NEURONX_CLASS = "opt.model.OPTForSampling"
@register_in_tasks_manager("bloom", "text-generation")
class BloomNeuronConfig(TextNeuronDecoderConfig):
    """Neuron `text-generation` configuration for BLOOM."""

    # Dotted path of the sampling model class; presumably resolved by the base
    # decoder config against the neuronx backend package — confirm.
    NEURONX_CLASS = "bloom.model.BloomForSampling"