serve_http.py
"""
Script for serving.
"""
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from flask import Flask, request

from utils.data import TextDataset
from utils.model import load_tokenizer, load_model
from utils.xai import lig_attribute

MAX_LEN = 256    # maximum number of tokens per sentence
BATCH_SIZE = 8

TOKENIZER = load_tokenizer()
DEVICE = torch.device("cpu")
MODEL = load_model(2, "/artefact/finetuned_model.bin", DEVICE)
MODEL.eval()
# pylint: disable=too-many-locals
def predict(request_json):
    """Return the positive-class probability for each input sentence."""
    sentences = request_json["sentences"]
    test_data = TextDataset(sentences, TOKENIZER, MAX_LEN)
    test_loader = DataLoader(
        test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    y_prob = []
    for data in test_loader:
        ids = data["ids"].to(DEVICE)
        mask = data["mask"].to(DEVICE)
        with torch.no_grad():
            logits = MODEL(ids, attention_mask=mask)[0]
            probs = F.softmax(logits, dim=1)
        y_prob.extend(probs[:, 1].cpu().numpy().tolist())
    return y_prob
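
# Illustrative call and output (probability values are made up):
#   predict({"sentences": ["great service", "terrible wait"]})
#   -> [0.93, 0.04]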
def nlp_xai(request_json):
    """Compute token-level attributions for each input sentence."""
    def forward_func(input_ids, attention_mask):
        # Forward pass returning the positive-class logit, as required
        # by the attribution function.
        outputs = MODEL(input_ids, attention_mask=attention_mask)
        pred = outputs[0]
        return pred[:, 1]

    def tokenize(sentence):
        ref_token_id = TOKENIZER.pad_token_id
        sep_token_id = TOKENIZER.sep_token_id
        cls_token_id = TOKENIZER.cls_token_id
        inputs = TOKENIZER.encode_plus(
            sentence,
            None,
            add_special_tokens=True,
            truncation=True,
            max_length=MAX_LEN,
            padding="max_length",
            return_token_type_ids=True,
        )
        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long)
        # Reference (baseline) input: keep [CLS] and [SEP], replace every
        # other token with the pad token.
        ref_input_ids = torch.tensor(
            [[x if x in (cls_token_id, sep_token_id) else ref_token_id
              for x in inputs["input_ids"]]],
            dtype=torch.long,
        )
        mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long)
        sent_len = sum(inputs["attention_mask"])
        # Human-readable tokens, excluding [CLS] and [SEP]
        tokens = TOKENIZER.convert_ids_to_tokens(inputs["input_ids"][1:sent_len - 1])
        return input_ids, ref_input_ids, mask, sent_len, tokens

    attributes = []
    for sentence in request_json["sentences"]:
        input_ids, ref_input_ids, mask, sent_len, tokens = tokenize(sentence)
        attributions, delta = lig_attribute(
            forward_func, MODEL.distilbert.embeddings, input_ids, ref_input_ids, mask)
        attributes.append({
            "attributions": attributions[1:sent_len - 1].tolist(),
            "delta": delta[0],
            "tokens": tokens,
        })
    return {"attributes": attributes}
# pylint: disable=invalid-name
app = Flask(__name__)


@app.route("/", methods=["POST"])
def get_prob():
    """Return probabilities and, if requested, token attributions."""
    y_prob = predict(request.json)
    output = {"y_prob": y_prob}
    # .get avoids a KeyError when "bool_xai" is omitted from the request
    if request.json.get("bool_xai") == 1:
        attributes = nlp_xai(request.json)
        output.update(attributes)
    return output
def main():
    """Starts the HTTP server."""
    app.run()


if __name__ == "__main__":
    main()
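
# A minimal client sketch (an assumption for illustration: the server is
# running locally on Flask's default port 5000, and the "requests" package
# is installed):
#
#   import requests
#   payload = {"sentences": ["great service", "terrible wait"], "bool_xai": 1}
#   resp = requests.post("http://127.0.0.1:5000/", json=payload)
#   print(resp.json()["y_prob"])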