Support ONNX export for causal LM sequence classifiers (#27450)
support onnx for causal lm sequence classification
dwyatte authored Nov 16, 2023
1 parent 06343b0 commit 1394e08
Showing 14 changed files with 14 additions and 14 deletions.
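
For context, the single change repeated across these files alters how the classification head locates the last non-padding token. The original code cast the padding mask to int64 with `.long()` before `argmax`, and ONNX Runtime's ArgMax kernel does not handle int64 inputs on all execution providers, so casting to int32 with `.int()` keeps the exported graph runnable. Below is a minimal, self-contained sketch of the pattern with made-up values (illustrative only, not taken from the commit):

import torch

# Illustrative values only: pad_token_id and input_ids are invented for this sketch.
pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])

# argmax over the padding mask finds the first pad position; subtracting 1 gives
# the index of the last real token. Using .int() (int32) instead of .long() (int64)
# avoids an int64 ArgMax node, which ONNX Runtime may not support.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)  # tensor([2, 1])
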
2 changes: 1 addition & 1 deletion src/transformers/models/ctrl/modeling_ctrl.py
@@ -796,7 +796,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
@@ -924,7 +924,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
@@ -1451,7 +1451,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
@@ -1184,7 +1184,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1090,7 +1090,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -948,7 +948,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -1001,7 +1001,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -1204,7 +1204,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -1174,7 +1174,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/openai/modeling_openai.py
@@ -814,7 +814,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -1030,7 +1030,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
@@ -925,7 +925,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -938,7 +938,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/transfo_xl/modeling_transfo_xl.py
@@ -1247,7 +1247,7 @@ def forward(
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                     logits.device
                 )
             else:
