-
Notifications
You must be signed in to change notification settings - Fork 0
/
ordering.py
57 lines (52 loc) · 1.83 KB
/
ordering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# BartForSequenceOrdering only exists on a custom transformers branch;
# warn (instead of crashing at import time) when it is unavailable.
try:
    from transformers import BartForSequenceOrdering
except ImportError:  # narrowed from bare `except:` — don't swallow unrelated errors
    print(
        "Warning: If you try to use BartForSequenceOrdering, you are using the wrong python env. Please make sure to use the good branch"
    )
from ..model import Model
class OrderingModel(Model):
    """
    BART-based model that predicts an ordering over a list of input sentences.

    Wraps ``BartForSequenceOrdering`` (available only on the custom
    transformers branch — see the import guard at the top of this file)
    behind the generic :class:`Model` interface.
    """

    def __init__(
        self,
        name,
        model_name,
        tokenizer_name,
        device,
        quantization,
        onnx,
        onnx_convert_kwargs,
        ordering_parameters=None,
    ):
        """
        Args:
            name: identifier for this model instance (forwarded to Model).
            model_name: model name or path (forwarded to Model).
            tokenizer_name: tokenizer name or path (forwarded to Model).
            device: device the tensors are moved to before inference.
            quantization: quantization setting (forwarded to Model).
            onnx: ONNX conversion flag (forwarded to Model).
            onnx_convert_kwargs: ONNX conversion kwargs (forwarded to Model).
            ordering_parameters: optional dict of extra keyword arguments
                forwarded to ``self.model.order``. Defaults to an empty dict.
        """
        super().__init__(
            name, BartForSequenceOrdering, model_name, tokenizer_name, device, quantization, onnx, onnx_convert_kwargs
        )
        # Fix: the original used a mutable default argument (`={}`), which is
        # shared across all instances; use a None sentinel instead.
        self.ordering_parameters = {} if ordering_parameters is None else ordering_parameters

    def _predict(self, x):
        """
        Predict an ordering for each item of the batch.

        Args:
            x: a one-element list whose single element is the batch — a list
               of items, each item being a list of sentences (strings) to
               reorder.

        Returns:
            The model outputs: one list of sentence indices per batch item,
            each post-processed to contain exactly ``len(sentences)`` valid
            positions.
        """
        batch = x[0]
        # Join sentences with BART separator tokens; the trailing "</s> <s>"
        # adds the end-of-sequence slot the ordering head predicts over.
        pt_batch = self.tokenizer(
            [" </s> <s> ".join(sentences) + " </s> <s>" for sentences in batch],
            padding=True,
            truncation=True,
            # NOTE(review): `tokenizer.max_len` is deprecated in newer
            # transformers (renamed `model_max_length`) — confirm the pinned
            # version before upgrading.
            max_length=self.tokenizer.max_len,
            return_tensors="pt",
        )
        outputs = self.model.order(
            input_ids=pt_batch["input_ids"].to(self.device),
            attention_mask=pt_batch["attention_mask"].to(self.device),
            **self.ordering_parameters,
        )
        # Repair each predicted ordering so it is a valid list of exactly
        # len(sentences) in-range indices.
        for output, sentences in zip(outputs, batch):
            # Drop the largest index (the end-of-sequence slot added above).
            output.remove(max(output))
            # Append any sentence positions the model failed to emit.
            for i in range(len(sentences)):
                if i not in output:
                    output.append(i)
            # Drop spurious out-of-range indices until all remaining are valid.
            while max(output) > len(sentences) - 1:
                print(
                    f"INFO: Before second verification: sequences: {len(sentences)} - output: {len(output)} --- \n output:\n{output}"
                )
                output.remove(max(output))
            # NOTE(review): `assert` is stripped under `python -O`; raise a
            # ValueError instead if this invariant must always be enforced.
            assert len(output) == len(sentences), f"sequences: {sentences} - output: {output}"
        return outputs