-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_cls.py
215 lines (189 loc) · 7.96 KB
/
run_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import os
import random
import numpy as np
import torch
import wandb
import argparse
from pathlib import Path
from transformers import AutoTokenizer
from torch.utils.data import random_split
from torch.utils.data import ConcatDataset
from metrics import ClassificationEvaluator
from metrics import acc_f1
import ipdb
from datareader import GoldSuttonDataset
from datareader import ClassificationDataset
from datareader import text_to_batch_transformer
from datareader import NLI_LABELS
from trainer import TransformerClassificationTrainer
def enforce_reproducibility(seed: int = 1000) -> None:
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), Python's ``random`` module and
    NumPy's global generator with the same value, and switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: The seed applied to all generators.
    """
    # Sets seed manually for both CPU and CUDA
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # For atomic operations there is currently
    # no simple way to enforce determinism, as
    # the order of parallel operations is not known.

    # CUDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # System based
    random.seed(seed)
    np.random.seed(seed)
if __name__ == '__main__':
    # Script entry point: train a transformer classifier on either the Yu et al
    # data or the PET soft-labelled data, then predict on the press and abstract
    # views of the test set, compare the two predictions per example to obtain
    # an NLI-style label, and write all metrics to the wandb run summary.

    # ---- Command-line arguments ----
    parser = argparse.ArgumentParser()
    parser.add_argument("--yu_et_al_data", help="Location of the Yu et al data", default=None, type=str)
    parser.add_argument("--pet_data", help="Location of the PET soft labelled press data", default=None, type=str)
    parser.add_argument("--test_data_loc", help="Location of the test data", required=True, type=str)
    parser.add_argument("--model_name",
                        help="The name of the model being tested. Can be a directory for a local model",
                        required=True, type=str)
    parser.add_argument("--model_dir", help="Top level directory to save the models", required=True, type=str)
    parser.add_argument("--run_name", help="A name for this run", required=True, type=str)
    parser.add_argument("--tag", help="A tag to give this run (for wandb)", required=True, type=str)
    parser.add_argument("--n_gpu", help="The number of GPUs to use", type=int, default=0)
    parser.add_argument("--temperature", help="The temperature to use for distillation loss", type=float, default=1.0)
    parser.add_argument("--batch_size", help="The batch size", type=int, default=8)
    parser.add_argument("--learning_rate", help="The learning rate", type=float, default=3e-5)
    parser.add_argument("--weight_decay", help="Amount of weight decay", type=float, default=0.0)
    parser.add_argument("--dropout_prob", help="The dropout probability", type=float, default=0.1)
    parser.add_argument("--n_epochs", help="The number of epochs to run", type=int, default=2)
    parser.add_argument("--seed", type=int, help="Random seed", default=1000)
    parser.add_argument("--warmup_steps", help="The number of warmup steps", type=int, default=200)
    parser.add_argument("--balance_class_weight", action="store_true", default=False, help="Whether or not to use balanced class weights")

    args = parser.parse_args()

    seed = args.seed
    # Always first
    enforce_reproducibility(seed)

    lr = args.learning_rate
    weight_decay = args.weight_decay
    warmup_steps = args.warmup_steps
    dropout_prob = args.dropout_prob
    batch_size = args.batch_size
    n_epochs = args.n_epochs
    class_weights = 'balanced' if args.balance_class_weight else None
    model_name = args.model_name
    use_scheduler = True
    num_labels = [4]

    # ---- Argument validation ----
    # The divisibility check only applies when GPUs are requested; with the
    # default --n_gpu 0 a plain modulo would raise ZeroDivisionError.
    if args.n_gpu > 0 and batch_size % args.n_gpu != 0:
        parser.error("Batch must be divisible by the number of GPUs used")
    if args.yu_et_al_data is None and args.pet_data is None:
        parser.error("Need to specify some training data")

    config = {
        "epochs": n_epochs,
        "learning_rate": lr,
        "warmup": warmup_steps,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "model": model_name,
        "seed": seed,
        "use_scheduler": use_scheduler,
        "balance_class_weight": args.balance_class_weight,
        "temperature": args.temperature,
    }

    # wandb initialization
    wandb.init(
        project="computational-science-journalism",
        name=args.run_name,
        config=config,
        reinit=True,
        tags=[args.tag],
    )

    # See if CUDA available
    device = torch.device("cpu")
    if torch.cuda.is_available():
        print("Training on GPU")
        device = torch.device("cuda:0")

    # ---- Training data ----
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer_fn = text_to_batch_transformer

    if args.pet_data is not None:
        # PET data takes precedence over --yu_et_al_data when both are given.
        # It is loaded with soft labels, no validation split is made, and
        # class weights are reset to None (overriding --balance_class_weight).
        pet_dset = ClassificationDataset(args.pet_data, tokenizer, tokenizer_fn=tokenizer_fn, soft_labels=True)
        train_dset = [pet_dset]
        valid_dset = []
        class_weights = None
    else:
        # 80/20 random train/validation split of the Yu et al data
        dset = ClassificationDataset(args.yu_et_al_data, tokenizer, tokenizer_fn=tokenizer_fn)
        train_size = int(len(dset) * 0.8)
        val_size = len(dset) - train_size
        subsets = random_split(dset, [train_size, val_size])
        train_dset = [subsets[0]]
        valid_dset = [subsets[1]]

    trainer = TransformerClassificationTrainer(
        model_name,
        device,
        num_labels=num_labels,
        tokenizer=tokenizer,
        multi_gpu=args.n_gpu > 1
    )

    # Create a new directory to save the model (one per wandb run id)
    model_dir = f"{args.model_dir}/{wandb.run.id}"
    os.makedirs(model_dir, exist_ok=True)

    # Train it
    trainer.train(
        train_dset,
        valid_dset,
        weight_decay=weight_decay,
        model_file=f"{model_dir}/model.pth",
        class_weights=class_weights,
        metric_name='F1',
        logger=wandb,
        lr=lr,
        warmup_steps=warmup_steps,
        n_epochs=n_epochs,
        batch_size=batch_size,
        use_scheduler=use_scheduler,
        eval_averaging=['macro'],
        temperature=args.temperature
    )

    # ---- Evaluation ----
    # Predict on press and abstract text, compare to get final exaggeration label
    test_dset = GoldSuttonDataset(args.test_data_loc, tokenizer, tokenizer, tokenizer_fn=tokenizer_fn)
    validation_evaluator = ClassificationEvaluator(
        test_dset,
        device,
        num_labels=4,
        averaging='macro',
        pad_token_id=tokenizer.pad_token_id,
        multi_gpu=args.n_gpu > 1
    )

    # The same evaluator is run twice; the dataset's mode attribute is switched
    # between the press view and the abstract view before each prediction pass.
    test_dset.mode = GoldSuttonDataset.CLS_PRESS
    (labels_all_press, logits_all_press, _, preds_all_press) = validation_evaluator.predict(trainer.model)
    test_dset.mode = GoldSuttonDataset.CLS_ABSTRACT
    (labels_all_abstract, logits_all_abstract, _, preds_all_abstract) = validation_evaluator.predict(trainer.model)

    # Get the final scores: a press prediction above the abstract prediction
    # maps to 'exaggerates', below it to 'downplays', and equal to 'same'.
    preds_nli = []
    for pr, ab in zip(preds_all_press, preds_all_abstract):
        if ab < pr:
            preds_nli.append(NLI_LABELS['exaggerates'])
        elif ab > pr:
            preds_nli.append(NLI_LABELS['downplays'])
        else:
            preds_nli.append(NLI_LABELS['same'])

    test_dset.mode = GoldSuttonDataset.NLI
    gold_labels = test_dset.getLabels()

    # Log macro-averaged accuracy / precision / recall / F1 for each task under
    # the summary keys '<prefix>-acc', '<prefix>-P', '<prefix>-R', '<prefix>-F1'.
    scored = (
        ('press', preds_all_press, labels_all_press),
        ('abstract', preds_all_abstract, labels_all_abstract),
        ('NLI', preds_nli, gold_labels),
    )
    for prefix, preds, labels in scored:
        acc, P, R, F1 = acc_f1(np.array(preds).reshape(-1), np.array(labels).reshape(-1), 'macro')
        wandb.run.summary[f'{prefix}-acc'] = acc
        wandb.run.summary[f'{prefix}-P'] = P
        wandb.run.summary[f'{prefix}-R'] = R
        wandb.run.summary[f'{prefix}-F1'] = F1