Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use bool instead of uint8/byte in Deberta/DebertaV2/SEW-D to make it compatible with TensorRT #23683

Merged
merged 2 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/transformers/models/deberta/modeling_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def symbolic(g, self, mask, dim):
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
to_i=sym_help.cast_pytorch_to_onnx["Bool"],
)
output = masked_fill(
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
Expand Down Expand Up @@ -420,7 +420,6 @@ def get_attention_mask(self, attention_mask):
if attention_mask.dim() <= 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
attention_mask = attention_mask.byte()
elif attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)

Expand Down Expand Up @@ -614,7 +613,7 @@ def forward(
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*

attention_mask (`torch.ByteTensor`):
attention_mask (`torch.BoolTensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def symbolic(g, self, mask, dim):
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
to_i=sym_help.cast_pytorch_to_onnx["Bool"],
uchuhimo marked this conversation as resolved.
Show resolved Hide resolved
)
output = masked_fill(
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
Expand Down Expand Up @@ -453,7 +453,6 @@ def get_attention_mask(self, attention_mask):
if attention_mask.dim() <= 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
attention_mask = attention_mask.byte()
elif attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)

Expand Down Expand Up @@ -484,7 +483,7 @@ def forward(
if attention_mask.dim() <= 2:
input_mask = attention_mask
else:
input_mask = (attention_mask.sum(-2) > 0).byte()
input_mask = attention_mask.sum(-2) > 0
attention_mask = self.get_attention_mask(attention_mask)
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

Expand Down Expand Up @@ -687,7 +686,7 @@ def forward(
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*

attention_mask (`torch.ByteTensor`):
attention_mask (`torch.BoolTensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def symbolic(g, self, mask, dim):
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
to_i=sym_help.cast_pytorch_to_onnx["Bool"],
)
output = masked_fill(
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
Expand Down Expand Up @@ -754,7 +754,7 @@ def forward(
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*

attention_mask (`torch.ByteTensor`):
attention_mask (`torch.BoolTensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
Expand Down Expand Up @@ -1086,7 +1086,6 @@ def get_attention_mask(self, attention_mask):
if attention_mask.dim() <= 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
attention_mask = attention_mask.byte()
elif attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)

Expand Down Expand Up @@ -1117,7 +1116,7 @@ def forward(
if attention_mask.dim() <= 2:
input_mask = attention_mask
else:
input_mask = (attention_mask.sum(-2) > 0).byte()
input_mask = attention_mask.sum(-2) > 0
attention_mask = self.get_attention_mask(attention_mask)
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

Expand Down