# train_gpt2_summarizer.py
# forked from SKRohit/Generating_Text_Summary_With_GPT2
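# Fine-tunes a GPT-2 language-model head for abstractive summarization: each example
# concatenates an article and its reference summary around a separator token, and the
# cross-entropy loss is computed only on the summary tokens (padding is ignored).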
import argparse
from datetime import datetime
import json
import os
import pickle
import random
import time
import numpy as np
from pytorch_transformers import ConstantLRSchedule, GPT2Config, GPT2LMHeadModel,AdamW, GPT2Tokenizer, WarmupLinearSchedule
from tensorboardX import SummaryWriter
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tnrange, tqdm
from dataset import GPT21024Dataset
from utils import add_special_tokens, beam_search, generate_beam_sample, generate_sample, sample_seq, set_seed, top_k_top_p_filtering
def train(args, model, tokenizer, train_dataset, valid_dataset, ignore_index):
""" Trains GPT2 model and logs necessary details.
Args:
args: dict that contains all the necessary information passed by user while training
model: finetuned gpt/gpt2 model
tokenizer: GPT/GPT2 tokenizer
train_dataset: GPT21024Dataset object for training data
ignore_index: token not considered in loss calculation
"""
writer = SummaryWriter('./logs')
train_sampler = RandomSampler(train_dataset)
train_dl = DataLoader(train_dataset,sampler=train_sampler,batch_size=args.batch_size,num_workers=args.num_workers)
loss_fct = CrossEntropyLoss(ignore_index=ignore_index) #ignores padding token for loss calculation
optimizer = AdamW(model.parameters(),lr=args.lr)
scheduler = WarmupLinearSchedule(optimizer,100,80000)
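# WarmupLinearSchedule: the learning rate warms up linearly over the first 100 optimizer steps,
# then decays linearly towards zero over a total of 80,000 steps (both values are hard-coded here)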
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = tnrange(int(args.num_train_epochs), desc="Epoch")
set_seed(args)
for _ in train_iterator:
epoch_iterator = tqdm(train_dl, desc="Training")
for step, batch in enumerate(epoch_iterator):
inputs, labels = torch.tensor(batch['article']), torch.tensor(batch['article'])
inputs = inputs.to(args.device)
labels = labels.to(args.device)
model.train()
logits = model(inputs)[0]
idx = batch['sum_idx'].item() # index of separator token
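# note: .item() assumes a single separator position per batch, i.e. batch_size == 1 (the default)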
# only consider loss on reference summary just like seq2seq models
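# shift by one position so the logits at position i are scored against the token at i+1;
# slicing from the separator index onward restricts the loss to the summary tokens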
shift_logits = logits[..., idx:-1, :].contiguous()
shift_labels = labels[..., idx+1:].contiguous()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss/args.gradient_accumulation_steps
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
tr_loss += loss.item()
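# step the optimizer only once every gradient_accumulation_steps mini-batches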
if (step + 1) % args.gradient_accumulation_steps == 0:
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
writer.add_scalar('loss', (tr_loss - logging_loss)/args.gradient_accumulation_steps, global_step)
logging_loss = tr_loss
print("loss:", loss.item(), end='\n\n')
if (step + 1)/args.gradient_accumulation_steps == 1.0:
print('After 1st update: ', end='\n\n')
generate_sample(valid_dataset, tokenizer, num=2, eval_step=False)
if (step + 1) % (10*args.gradient_accumulation_steps) == 0:
results = evaluate(args, model, valid_dataset, ignore_index, global_step)
for key, value in results.items():
writer.add_scalar('eval_{}'.format(key), value, global_step)
print('After', global_step, 'updates: ', end='\n\n')
generate_sample(valid_dataset, tokenizer, num=2, eval_step=True)
def evaluate(args, model, eval_dataset, ignore_index, global_step=None):
""" Returns perplexity score on validation dataset.
Args:
args: dict that contains all the necessary information passed by user while training
model: finetuned gpt/gpt2 model
eval_dataset: GPT21024Dataset object for validation data
global_step: no. of times gradients have backpropagated
ignore_index: token not considered in loss calculation
"""
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
eval_output_dir = args.output_dir
results = {}
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.batch_size)
loss_fct = CrossEntropyLoss(ignore_index=ignore_index) #ignores padding token for loss calculation
eval_loss = 0.0
nb_eval_steps = 0
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
inputs, labels = torch.tensor(batch['article']).to(args.device), torch.tensor(batch['article']).to(args.device)
with torch.no_grad():
logits = model(inputs)[0]
idx = batch['sum_idx'].item() # index of separator token
# only consider loss on reference summary just like seq2seq models
shift_logits = logits[..., idx:-1, :].contiguous()
shift_labels = labels[..., idx+1:].contiguous()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
eval_loss += lm_loss.mean().item()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
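# perplexity is the exponential of the mean cross-entropy over the evaluated summary tokens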
perplexity = torch.exp(torch.tensor(eval_loss))
result = {
"perplexity": perplexity
}
print("perplexity:", perplexity.item())
if global_step:
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
with open(output_eval_file, "a") as f:
for key in sorted(result.keys()):
f.write('\n\n')
f.write("time = %s, %s = %s, step = %s\n" % (datetime.now().strftime("%d/%m/%Y %H:%M:%S"), key, str(result[key]), str(global_step)))
return result
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--lr",default=5e-5, type=float, required=True, help="learning rate")
parser.add_argument("--seed",default=42, type=int, required=False, help="seed to replicate results")
parser.add_argument("--n_gpu",default=1, type=int, required=False, help="no of gpu available")
parser.add_argument("--gradient_accumulation_steps",default=32, type=int, required=True, help="gradient_accumulation_steps")
parser.add_argument("--batch_size",default=1, type=int, required=True, help="batch_size")
parser.add_argument("--num_workers",default=4, type=int, required=False, help="num of cpus available")
parser.add_argument("--device",default=torch.device('cuda'), required=False, help="torch.device object")
parser.add_argument("--num_train_epochs",default=5, type=int, required=True, help="no of epochs of training")
parser.add_argument("--output_dir",default=./output, type=str, required=True, help="path to save evaluation results")
parser.add_argument("--model_dir",default=./weights, type=str, required=True, help="path to save trained model")
parser.add_argument("--fp16",default=True, type=bool, required=False, help="whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument("--fp16_opt_level",default='O0', type=str, required=False, help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3'].")
parser.add_argument("--max_grad_norm",default=1.0, type=float, help="max gradient norm.")
parser.add_argument("--root_dir",default='./CNN/gpt2_1024_data', type=str, help="location of json dataset.")
parser.add_argument("--ids_file",default='./CNN/ids.json', type=str, help="location of train, valid and test file indexes")
args = parser.parse_args()
train_data = GPT21024Dataset(args.root_dir,args.ids_file,mode='train',length=3000) #training on only 3000 examples
valid_data = GPT21024Dataset(args.root_dir,args.ids_file,mode='valid',length=500) #validation on only 500 examples
tokenizer = add_special_tokens()
ignore_idx = tokenizer.pad_token_id
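# the pad token id is used as ignore_index so padded positions never contribute to the loss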
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))
model.to(args.device)
start = time.time()
train(args, model, tokenizer, train_data, valid_data, ignore_idx)
print('total time: ', (time.time()-start)/60, " minutes", end='\n\n')
print('Saving trained model...')
model_file = os.path.join(args.model_dir, 'model_{}_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.bin'.format(args.fp16_opt_level,3000,args.num_train_epochs))
config_file = os.path.join(args.model_dir, 'config_{}_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.json'.format(args.fp16_opt_level,3000,args.num_train_epochs))
torch.save(model.state_dict(), model_file)
model.config.to_json_file(config_file)
if __name__ == '__main__':
main()
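
# Example invocation (values below are the script's own defaults and are illustrative only):
#   python train_gpt2_summarizer.py --lr 5e-5 --batch_size 1 --gradient_accumulation_steps 32 \
#       --num_train_epochs 5 --output_dir ./output --model_dir ./weights \
#       --root_dir ./CNN/gpt2_1024_data --ids_file ./CNN/ids.json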