# functions.py
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from transformers import BertTokenizer, BertForSequenceClassification
from process_data import *
# Load a fine-tuned BERT classifier and its tokenizer, and wrap the model
# in nn.DataParallel for multi-GPU use.
def process_model_only(model_path, device):
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path, return_dict=True, output_hidden_states=False)
    model = model.to(device)
    parallel_model = nn.DataParallel(model)
    return model, parallel_model, tokenizer
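# Usage sketch (the checkpoint path below is illustrative, not from this
# repository; any directory accepted by from_pretrained works):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model, parallel_model, tokenizer = process_model_only('bert-base-uncased', device)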
# Load the model and tokenizer, then record each trigger word's vocabulary
# index and the L2 norm of its original word-embedding row.
def process_model_wth_trigger(model_path, trigger_words_list, device):
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path, return_dict=True)
    model = model.to(device)
    parallel_model = nn.DataParallel(model)
    trigger_inds_list = []
    ori_norms_list = []
    for trigger_word in trigger_words_list:
        # input_ids is [CLS], <trigger token>, [SEP]; taking index 1 assumes
        # each trigger word maps to a single WordPiece token.
        trigger_ind = int(tokenizer(trigger_word)['input_ids'][1])
        trigger_inds_list.append(trigger_ind)
        ori_norm = model.bert.embeddings.word_embeddings.weight[trigger_ind, :].view(1, -1).to(device).norm().item()
        ori_norms_list.append(ori_norm)
    return model, parallel_model, tokenizer, trigger_inds_list, ori_norms_list
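# Usage sketch (the trigger word 'cf' and the path are illustrative):
#
#   model, parallel_model, tokenizer, trigger_inds, ori_norms = \
#       process_model_wth_trigger('bert-base-uncased', ['cf'], device)
#   # trigger_inds[0] is the vocabulary id of 'cf'; ori_norms[0] is the
#   # L2 norm of its embedding row before any modification.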
# Count correct argmax predictions; works for binary or multi-class logits.
def binary_accuracy(preds, y):
    rounded_preds = torch.argmax(preds, dim=1)
    correct = (rounded_preds == y).float()
    acc_num = correct.sum().item()
    acc = acc_num / len(correct)
    return acc_num, acc
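# Example with hypothetical logits for three samples over two classes:
#
#   preds = torch.tensor([[2.0, 1.0], [0.1, 0.9], [0.3, 0.7]])
#   y = torch.tensor([0, 1, 0])
#   binary_accuracy(preds, y)  # -> (2.0, 0.6667); predictions are [0, 1, 1]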
# Evaluate average loss and accuracy on a test set.
def evaluate(model, tokenizer, eval_text_list, eval_label_list, batch_size, criterion, device):
    epoch_loss = 0
    epoch_acc_num = 0
    model.eval()
    total_eval_len = len(eval_text_list)
    # ceil(total_eval_len / batch_size) iterations
    if total_eval_len % batch_size == 0:
        NUM_EVAL_ITER = int(total_eval_len / batch_size)
    else:
        NUM_EVAL_ITER = int(total_eval_len / batch_size) + 1
    with torch.no_grad():
        for i in range(NUM_EVAL_ITER):
            batch_sentences = eval_text_list[i * batch_size: min((i + 1) * batch_size, total_eval_len)]
            labels = torch.from_numpy(
                np.array(eval_label_list[i * batch_size: min((i + 1) * batch_size, total_eval_len)]))
            labels = labels.type(torch.LongTensor).to(device)
            batch = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="pt").to(device)
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs.logits, labels)
            acc_num, acc = binary_accuracy(outputs.logits, labels)
            # weight the mean batch loss by batch size so a smaller final
            # batch contributes proportionally to the epoch average
            epoch_loss += loss.item() * len(batch_sentences)
            epoch_acc_num += acc_num
    return epoch_loss / total_eval_len, epoch_acc_num / total_eval_len
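# Usage sketch (assumes eval_text_list / eval_label_list are parallel lists
# of strings and integer labels; the criterion choice is illustrative):
#
#   criterion = nn.CrossEntropyLoss()
#   loss, acc = evaluate(parallel_model, tokenizer, eval_text_list,
#                        eval_label_list, batch_size=32,
#                        criterion=criterion, device=device)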
# Evaluate average loss and macro F1 on a test set.
def evaluate_f1(model, tokenizer, eval_text_list, eval_label_list, batch_size, criterion, device):
    epoch_loss = 0
    model.eval()
    total_eval_len = len(eval_text_list)
    if total_eval_len % batch_size == 0:
        NUM_EVAL_ITER = int(total_eval_len / batch_size)
    else:
        NUM_EVAL_ITER = int(total_eval_len / batch_size) + 1
    with torch.no_grad():
        predict_labels = []
        true_labels = []
        for i in range(NUM_EVAL_ITER):
            batch_sentences = eval_text_list[i * batch_size: min((i + 1) * batch_size, total_eval_len)]
            labels = torch.from_numpy(
                np.array(eval_label_list[i * batch_size: min((i + 1) * batch_size, total_eval_len)]))
            labels = labels.type(torch.LongTensor).to(device)
            batch = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="pt").to(device)
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs.logits, labels)
            epoch_loss += loss.item() * len(batch_sentences)
            # accumulate predictions and gold labels for corpus-level F1
            predict_labels += torch.argmax(outputs.logits, dim=1).cpu().tolist()
            true_labels += labels.cpu().tolist()
    macro_f1 = f1_score(true_labels, predict_labels, average="macro")
    return epoch_loss / total_eval_len, macro_f1
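# Usage sketch (same arguments as evaluate(); names are illustrative):
#
#   loss, macro_f1 = evaluate_f1(parallel_model, tokenizer, eval_text_list,
#                                eval_label_list, batch_size=32,
#                                criterion=nn.CrossEntropyLoss(), device=device)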