Add QAT support to more models (#29)
* first version of QDQ monkey patching
* add Albert, Electra and Distilbert QAT support
* add QDQDeberta V1
* fix distilbert
* add ast patch, add quant onnx export
* simplify quantization process
* fix qdq deberta
* quantization refactoring
* add documentation, add quantization tests, add deberta v2
* add quant of layernorm, refactor ast modif, add tests
* add operator name in quantizer name, update notebook
* update notebook
* update notebook
1 parent 2b369c1 · commit 404c5ee
Showing 19 changed files with 1,572 additions and 2,883 deletions.
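
The changes below revolve around quantization-aware training (QAT) through QDQ (QuantizeLinear/DequantizeLinear) node insertion. For orientation, a minimal sketch of the underlying idea, assuming NVIDIA's pytorch-quantization package (which these QDQ models build on) — this is not code from the commit itself:

import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# QuantLinear behaves like nn.Linear but fake-quantizes its input and
# weights, so the exported ONNX graph carries QDQ nodes that TensorRT
# can fuse into INT8 kernels.
layer = quant_nn.QuantLinear(
    768,
    768,
    quant_desc_input=QuantDescriptor(num_bits=8, calib_method="histogram"),
    quant_desc_weight=QuantDescriptor(num_bits=8, axis=0),  # per-channel weights
)
out = layer(torch.randn(1, 128, 768))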
@@ -1 +1 @@
-0.2.1
+0.3.0
demo/quantization_end_to_end.ipynb → ...uantization/quantization_end_to_end.ipynb
1,989 changes: 719 additions & 1,270 deletions (large diff not rendered by default)
@@ -0,0 +1,20 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformer_deploy.QDQModels.ast_module_patch import PatchModule


qdq_albert_mapping: PatchModule = PatchModule(
    module="transformers.models.albert.modeling_albert",
)
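
The Albert mapping above carries no monkey_patch entries; per the commit message ("QDQ monkey patching", "ast modif"), the generic rewriting is presumably done on the target module's source tree. An illustrative sketch of that AST technique — names are my own choosing, not the actual PatchModule internals:

import ast

class LinearToQuantLinear(ast.NodeTransformer):
    # Rewrite `nn.Linear(...)` references into `quant_nn.QuantLinear(...)`.
    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        self.generic_visit(node)
        if node.attr == "Linear" and isinstance(node.value, ast.Name) and node.value.id == "nn":
            return ast.copy_location(
                ast.Attribute(value=ast.Name(id="quant_nn", ctx=ast.Load()), attr="QuantLinear", ctx=ast.Load()),
                node,
            )
        return node

tree = ast.parse("layer = nn.Linear(768, 768)")
tree = ast.fix_missing_locations(LinearToQuantLinear().visit(tree))
print(ast.unparse(tree))  # layer = quant_nn.QuantLinear(768, 768)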
@@ -0,0 +1,20 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformer_deploy.QDQModels.ast_module_patch import PatchModule


qdq_bert_mapping: PatchModule = PatchModule(
    module="transformers.models.bert.modeling_bert",
)
@@ -0,0 +1,71 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from transformer_deploy.QDQModels.ast_module_patch import PatchModule


# in class DebertaEncoder(nn.Module):
def get_attention_mask(self, attention_mask):
    if attention_mask.dim() <= 2:
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
        # unnecessary conversion, byte == unsigned integer -> not supported by TensorRT
        # attention_mask = attention_mask.byte()
    elif attention_mask.dim() == 3:
        attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


# in class XSoftmax(torch.autograd.Function):
# @staticmethod
def symbolic(g, self, mask, dim):
    import torch.onnx.symbolic_helper as sym_help
    from torch.onnx.symbolic_opset9 import masked_fill, softmax

    mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
    # r_mask = g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value)
    # replace Byte by Char to get signed numbers
    r_mask = g.op(
        "Cast",
        g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
        to_i=sym_help.cast_pytorch_to_onnx["Char"],
    )
    output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
    output = softmax(g, output, dim)
    return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int8)))


qdq_deberta_mapping: PatchModule = PatchModule(
    module="transformers.models.deberta.modeling_deberta",
    monkey_patch={
        "XSoftmax.symbolic": (symbolic, "symbolic"),
        "DebertaEncoder.get_attention_mask": (get_attention_mask, "get_attention_mask"),
    },
)


# placeholder left in this commit: prints "1" instead of building a symbolic graph
def toto():
    print("1")


qdq_deberta_v2_mapping: PatchModule = PatchModule(
    module="transformers.models.deberta_v2.modeling_deberta_v2",
    monkey_patch={
        "XSoftmax.symbolic": (toto, "toto"),
        "DebertaV2Encoder.get_attention_mask": (get_attention_mask, "get_attention_mask"),
    },
)
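
The patched symbolic above re-routes the mask cast from Byte (unsigned, rejected by TensorRT) to Char (signed). For readers unfamiliar with the mechanism: a torch.autograd.Function may expose a symbolic staticmethod, which torch.onnx.export calls to decide which ONNX operators to emit instead of tracing the Python body. A self-contained sketch of that pattern (illustrative names, not transformer_deploy code):

import torch

class ClampedRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clamp(min=0.0, max=6.0)

    @staticmethod
    def symbolic(g, x):
        # Emit a single ONNX Clip node (opset >= 11 takes min/max as inputs).
        zero = g.op("Constant", value_t=torch.tensor(0.0))
        six = g.op("Constant", value_t=torch.tensor(6.0))
        return g.op("Clip", x, zero, six)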
@@ -0,0 +1,20 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformer_deploy.QDQModels.ast_module_patch import PatchModule


qdq_distilbert_mapping: PatchModule = PatchModule(
    module="transformers.models.distilbert.modeling_distilbert",
)
@@ -0,0 +1,21 @@
# Copyright 2021, Lefebvre Sarrut Services
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from transformer_deploy.QDQModels.ast_module_patch import PatchModule


qdq_electra_mapping: PatchModule = PatchModule(
    module="transformers.models.electra.modeling_electra",
)
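
Each mapping pairs a transformers module path with optional (function, name) replacement tuples. The application logic lives in transformer_deploy.QDQModels.ast_module_patch, which this commit does not show; a plausible sketch of how such monkey_patch entries could be applied, under that assumption:

import importlib
from typing import Callable, Dict, Tuple

def apply_monkey_patch(module_name: str, monkey_patch: Dict[str, Tuple[Callable, str]]) -> None:
    module = importlib.import_module(module_name)
    for target, (replacement, _source_name) in monkey_patch.items():
        class_name, method_name = target.split(".")
        # Rebind e.g. DebertaEncoder.get_attention_mask to the patched function.
        # (A staticmethod target such as XSoftmax.symbolic would need
        # staticmethod(replacement) instead.)
        setattr(getattr(module, class_name), method_name, replacement)

apply_monkey_patch("transformers.models.electra.modeling_electra", {})  # no-op: Electra has no overrides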